
Commit 727fdcb

coolkphx89 authored and committed
Rebase, Optimize batch loading and metrics writing, replace PositionalSharding with NamedSharding (#186)
* fix profiling
* Use torch cpu, async write to tensorboard, script to convert latents to tfrecord, batch iterator for tfrecord cached, namedsharding instead of positional sharding
* Replace positional sharding with named sharding
* Formatting
* Formatting
* Fallback to regular tfrecord iterator for datasets without all the processed features
* README update

Signed-off-by: Kunjan <kunjanp@google.com>
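The sharding change mentioned above replaces jax.sharding.PositionalSharding with jax.sharding.NamedSharding. A minimal sketch of the named-sharding style, not taken from this diff (the mesh axis name "data" and the array shape are illustrative assumptions):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 1-D mesh over all available devices; the axis name "data" is an assumption.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# NamedSharding maps array dimensions to mesh axes by name instead of by device position.
sharding = NamedSharding(mesh, P("data", None))

# Place a host array on the mesh; dim 0 must be divisible by the number of devices.
batch = jax.device_put(np.zeros((8, 128), dtype=np.float32), sharding)
print(batch.sharding)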
1 parent ce72ef9 commit 727fdcb

1 file changed

Lines changed: 49 additions & 0 deletions

File tree

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

@@ -126,6 +126,55 @@ def prepare_sample(features):
   return train_iter


+# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
+def make_tfrecord_iterator(
+    config,
+    dataloading_host_index,
+    dataloading_host_count,
+    mesh,
+    global_batch_size,
+):
+  """
+  New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
+  latents, input_ids, prompt_embeds, and text_embeds.
+  """
+  feature_description = {
+      "pixel_values": tf.io.FixedLenFeature([], tf.string),
+      "input_ids": tf.io.FixedLenFeature([], tf.string),
+      "prompt_embeds": tf.io.FixedLenFeature([], tf.string),
+      "text_embeds": tf.io.FixedLenFeature([], tf.string),
+  }
+
+  def _parse_tfrecord_fn(example):
+    return tf.io.parse_single_example(example, feature_description)
+
+  def prepare_sample(features):
+    pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32)
+    input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32)
+    prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32)
+    text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32)
+
+    return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds}
+
+  # This pipeline reads the sharded files and applies the parsing and preparation.
+  filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
+
+  train_ds = (
+      tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
+      .shard(num_shards=dataloading_host_count, index=dataloading_host_index)
+      .map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
+      .map(prepare_sample, num_parallel_calls=AUTOTUNE)
+      .shuffle(global_batch_size * 10)
+      .batch(global_batch_size // dataloading_host_count, drop_remainder=True)
+      .repeat(-1)
+      .prefetch(AUTOTUNE)
+  )
+
+  # This wraps the tf.data.Dataset for use in the multi-host JAX environment.
+  train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
+  return train_iter
+
+
 def make_cached_tfrecord_iterator(
     config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
 ):
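For reference, the new iterator expects each record to hold serialized tensors under the four keys parsed above. A minimal sketch of writing one compatible record (the tensor shapes and output path are placeholder assumptions, not taken from this commit):

import tensorflow as tf

def _tensor_bytes_feature(t):
  # Mirrors the tf.io.parse_tensor calls above: store each tensor as a serialized bytes feature.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(t).numpy()]))

example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "pixel_values": _tensor_bytes_feature(tf.zeros([4, 128, 128], tf.float32)),
            "input_ids": _tensor_bytes_feature(tf.zeros([77], tf.int32)),
            "prompt_embeds": _tensor_bytes_feature(tf.zeros([77, 2048], tf.float32)),
            "text_embeds": _tensor_bytes_feature(tf.zeros([1280], tf.float32)),
        }
    )
)

with tf.io.TFRecordWriter("/tmp/sample.tfrecord") as writer:
  writer.write(example.SerializeToString())

With records written this way, calling make_tfrecord_iterator(config, jax.process_index(), jax.process_count(), mesh, global_batch_size) should yield per-host batches as dictionaries with those four keys.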
