Skip to content

Commit bdded09

Browse files
committed
Fallback to regular tfrecord iterator for datasets without all the processed features
Signed-off-by: Kunjan <kunjanp@google.com>
1 parent f418a0c commit bdded09

1 file changed

Lines changed: 8 additions & 3 deletions

File tree

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import tensorflow.experimental.numpy as tnp
2020
from datasets import load_dataset, load_from_disk
2121

22-
from maxdiffusion import multihost_dataloading
22+
from maxdiffusion import multihost_dataloading, max_logging
2323

2424
AUTOTUNE = tf.data.AUTOTUNE
2525

@@ -73,7 +73,6 @@ def make_tf_iterator(
7373
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
7474
return train_iter
7575

76-
7776
def make_cached_tfrecord_iterator(
7877
config,
7978
dataloading_host_index,
@@ -105,6 +104,7 @@ def prepare_sample(features):
105104

106105
# This pipeline reads the sharded files and applies the parsing and preparation.
107106
filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
107+
108108
train_ds = (
109109
tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
110110
.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
@@ -133,8 +133,13 @@ def make_tfrecord_iterator(
133133
check out preparation script
134134
maxdiffusion/pedagogical_examples/to_tfrecords.py
135135
"""
136-
if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location):
136+
137+
# set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
138+
# Use pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert a tf preprocessed dataset to tfrecord.
139+
# The dataset cache in the GitHub runner test doesn't contain all the features since it's shared, so use the default tfrecord iterator.
140+
if config.cache_latents_text_encoder_outputs and os.path.isdir(config.dataset_save_location) and config.get("load_tfrecord_cached", False):
137141
return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)
142+
138143
feature_description = {
139144
"moments": tf.io.FixedLenFeature([], tf.string),
140145
"clip_embeddings": tf.io.FixedLenFeature([], tf.string),

0 commit comments

Comments
 (0)