Skip to content

Commit d392bf4

Browse files
coolkphx89
authored and committed
Optimize batch loading and metrics writing, replace PositionalSharding with NamedSharding (#186)
* fix profiling * Use torch cpu, async write to tensorboard, script to convert latents to tfrecord, batch iterator for tfrecord cached, namedsharding instead of positional sharding Signed-off-by: Kunjan <kunjanp@google.com> * Replace positional sharding with named sharding Signed-off-by: Kunjan <kunjanp@google.com> * Formatting Signed-off-by: Kunjan <kunjanp@google.com> * Formatting Signed-off-by: Kunjan <kunjanp@google.com> * Fallback to regular tfrecord iterator for datasets without all the processed features Signed-off-by: Kunjan <kunjanp@google.com> * README update --------- Signed-off-by: Kunjan <kunjanp@google.com>
1 parent eaef265 commit d392bf4

2 files changed

Lines changed: 56 additions & 15 deletions

File tree

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,53 @@ def make_tf_iterator(
7878
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
7979
return train_iter
8080

81+
def make_cached_tfrecord_iterator(
    config,
    dataloading_host_index,
    dataloading_host_count,
    mesh,
    global_batch_size,
):
  """Build a multi-host iterator over fully pre-processed TFRecords.

  Each record is expected to carry all four serialized tensors produced by
  the caching step: pixel_values (latents), input_ids, prompt_embeds, and
  text_embeds.

  Args:
    config: run config; only `train_data_dir` is read here.
    dataloading_host_index: index of this host among the data-loading hosts.
    dataloading_host_count: total number of data-loading hosts.
    mesh: JAX device mesh the batches will be sharded over.
    global_batch_size: batch size across all hosts; each host reads
      `global_batch_size // dataloading_host_count` examples per step.

  Returns:
    A MultiHostDataLoadIterator yielding per-host batches.
  """
  # One serialized-bytes feature per key; the same spec object can be shared.
  serialized_bytes = tf.io.FixedLenFeature([], tf.string)
  feature_names = ("pixel_values", "input_ids", "prompt_embeds", "text_embeds")
  schema = {name: serialized_bytes for name in feature_names}

  # Target dtype for deserializing each feature back into a dense tensor.
  decode_dtypes = {
      "pixel_values": tf.float32,
      "input_ids": tf.int32,
      "prompt_embeds": tf.float32,
      "text_embeds": tf.float32,
  }

  def _decode_record(record):
    # Parse the example, then turn every serialized tensor back into its
    # dense form in a single pass.
    parsed = tf.io.parse_single_example(record, schema)
    return {
        name: tf.io.parse_tensor(parsed[name], out_type=dtype)
        for name, dtype in decode_dtypes.items()
    }

  # Every file under the training directory is treated as a record shard.
  shard_paths = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))

  dataset = tf.data.TFRecordDataset(shard_paths, num_parallel_reads=AUTOTUNE)
  dataset = dataset.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
  dataset = dataset.map(_decode_record, num_parallel_calls=AUTOTUNE)
  dataset = dataset.shuffle(global_batch_size * 10)
  dataset = dataset.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
  dataset = dataset.repeat(-1)
  dataset = dataset.prefetch(AUTOTUNE)

  # Wrap the tf.data.Dataset for the multi-host JAX environment.
  return multihost_dataloading.MultiHostDataLoadIterator(dataset, mesh)
127+
81128

82129
def make_cached_tfrecord_iterator(
83130
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
@@ -120,22 +167,12 @@ def make_tfrecord_iterator(
120167

121168
# set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
122169
# pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
123-
# Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
124-
if (
125-
config.cache_latents_text_encoder_outputs
170+
# Dataset cache in the GitHub runner test doesn't contain all the features since it's shared; use the default tfrecord iterator.
171+
if (config.cache_latents_text_encoder_outputs
126172
and os.path.isdir(config.dataset_save_location)
127-
and "load_tfrecord_cached" in config.get_keys()
128-
and config.load_tfrecord_cached
129-
):
130-
return make_cached_tfrecord_iterator(
131-
config,
132-
dataloading_host_index,
133-
dataloading_host_count,
134-
mesh,
135-
global_batch_size,
136-
feature_description,
137-
prepare_sample_fn,
138-
)
173+
and 'load_tfrecord_cached'in config.get_keys()
174+
and config.load_tfrecord_cached):
175+
return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)
139176

140177
feature_description = {
141178
"moments": tf.io.FixedLenFeature([], tf.string),

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,11 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H
198198
# This replaces random params with the model.
199199
params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
200200
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
201+
<<<<<<< HEAD
201202
params = jax.device_put(params, NamedSharding(mesh, P()))
203+
=======
204+
params = jax.device_put(params, NamedSharding(devices_array, P()))
205+
>>>>>>> f344ab0 (Optimize batch loading and metrics writing, replace PositionalSharding with NamedSharding (#186))
202206
wan_vae = nnx.merge(graphdef, params)
203207
p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
204208
# Shard

0 commit comments

Comments
 (0)