Skip to content

Commit afc8882

Browse files
author
Juan Acevedo
committed
training pipeline with image dataset.
1 parent 1536f42 commit afc8882

7 files changed

Lines changed: 162 additions & 68 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,14 +141,15 @@ ici_tensor_parallelism: 1
141141
# Replace with dataset path or train_data_dir. One has to be set.
142142
dataset_name: 'diffusers/pokemon-gpt4-captions'
143143
train_split: 'train'
144-
dataset_type: 'tf'
144+
dataset_type: 'tfrecord'
145145
cache_latents_text_encoder_outputs: True
146146
# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
147147
# only apply to small dataset that fits in memory
148148
# prepare image latents and text encoder outputs
149149
# Reduce memory consumption and reduce step time during training
150150
# transformed dataset is saved at dataset_save_location
151-
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
151+
dataset_save_location: ''
152+
load_tfrecord_cached: True
152153
train_data_dir: ''
153154
dataset_config_name: ''
154155
jax_cache_dir: ''

src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,25 +112,26 @@ def generate_dataset(config, pipeline):
112112
for i in range(0, len(ds), batch_size):
113113
rng, new_rng = jax.random.split(rng)
114114
text = ds[i:i+batch_size]['text']
115-
video = ds[i:i+batch_size]['image']
115+
videos = ds[i:i+batch_size]['image']
116116

117-
video = [np.expand_dims(np.array(i), axis=0) for i in video]
118-
video = video_processor.preprocess_video(video, height=config.height, width=config.width)
119-
video = jnp.array(np.array(video), dtype=config.weights_dtype)
117+
videos = [video_processor.preprocess_video([video], height=config.height, width=config.width) for video in videos]
118+
video = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
120119
with mesh:
121120
latents = p_vae_encode(video=video, rng=new_rng)
121+
latents = jnp.transpose(latents, (0, 4, 1, 2, 3))
122122
encoder_hidden_states = text_encode(pipeline, text)
123-
example = create_example(latents, encoder_hidden_states)
124-
writer.write(example)
125-
shard_record_count += batch_size
126-
global_record_count += batch_size
123+
for latent, encoder_hidden_state in zip(latents, encoder_hidden_states):
124+
writer.write(create_example(latent, encoder_hidden_state))
125+
shard_record_count += 1
126+
global_record_count += 1
127+
127128
if shard_record_count >= no_records_per_shard:
128129
writer.close()
130+
tf_rec_num +=1
129131
writer = tf.io.TFRecordWriter(
130132
tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard))
131133
)
132134
shard_record_count = 0
133-
tf_rec_num +=1
134135

135136

136137

src/maxdiffusion/generate_wan.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
from maxdiffusion.utils import export_to_video
2222

2323

24-
def run(config):
24+
def run(config, pipeline=None, filename_prefix=''):
2525
print("seed: ", config.seed)
26-
pipeline = WanPipeline.from_pretrained(config)
26+
if pipeline is None:
27+
pipeline = WanPipeline.from_pretrained(config)
2728
s0 = time.perf_counter()
2829

2930
# Skip layer guidance
@@ -59,7 +60,7 @@ def run(config):
5960

6061
print("compile time: ", (time.perf_counter() - s0))
6162
for i in range(len(videos)):
62-
export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
63+
export_to_video(videos[i], f"{filename_prefix}wan_output_{config.seed}_{i}.mp4", fps=config.fps)
6364
s0 = time.perf_counter()
6465
videos = pipeline(
6566
prompt=prompt,

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -79,37 +79,25 @@ def make_cached_tfrecord_iterator(
7979
dataloading_host_count,
8080
mesh,
8181
global_batch_size,
82+
feature_description,
83+
prepare_sample_fn
8284
):
8385
"""
8486
New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
8587
latents, input_ids, prompt_embeds, and text_embeds.
8688
"""
87-
feature_description = {
88-
"pixel_values": tf.io.FixedLenFeature([], tf.string),
89-
"input_ids": tf.io.FixedLenFeature([], tf.string),
90-
"prompt_embeds": tf.io.FixedLenFeature([], tf.string),
91-
"text_embeds": tf.io.FixedLenFeature([], tf.string),
92-
}
9389

9490
def _parse_tfrecord_fn(example):
9591
return tf.io.parse_single_example(example, feature_description)
9692

97-
def prepare_sample(features):
98-
pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32)
99-
input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32)
100-
prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32)
101-
text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32)
102-
103-
return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds}
104-
10593
# This pipeline reads the sharded files and applies the parsing and preparation.
10694
filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
10795

10896
train_ds = (
10997
tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
11098
.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
11199
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
112-
.map(prepare_sample, num_parallel_calls=AUTOTUNE)
100+
.map(prepare_sample_fn, num_parallel_calls=AUTOTUNE)
113101
.shuffle(global_batch_size * 10)
114102
.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
115103
.repeat(-1)
@@ -128,6 +116,8 @@ def make_tfrecord_iterator(
128116
dataloading_host_count,
129117
mesh,
130118
global_batch_size,
119+
feature_description,
120+
prepare_sample_fn
131121
):
132122
"""Iterator for TFRecord format. For Laion dataset,
133123
check out preparation script
@@ -136,12 +126,20 @@ def make_tfrecord_iterator(
136126

137127
# set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
138128
# pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
139-
# Datset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
129+
# Dataset cache in the GitHub runner test doesn't contain all the features since it's shared; use the default tfrecord iterator.
140130
if (config.cache_latents_text_encoder_outputs
141131
and os.path.isdir(config.dataset_save_location)
142132
and 'load_tfrecord_cached' in config.get_keys()
143133
and config.load_tfrecord_cached):
144-
return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)
134+
return make_cached_tfrecord_iterator(
135+
config,
136+
dataloading_host_index,
137+
dataloading_host_count,
138+
mesh,
139+
global_batch_size,
140+
feature_description,
141+
prepare_sample_fn
142+
)
145143

146144
feature_description = {
147145
"moments": tf.io.FixedLenFeature([], tf.string),

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,18 @@ def make_data_iterator(
5050
global_batch_size,
5151
tokenize_fn=None,
5252
image_transforms_fn=None,
53+
feature_description=None,
54+
prepare_sample_fn=None
5355
):
5456
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""
57+
58+
if config.dataset_type == "hf" or config.dataset_type == "tf":
59+
if tokenize_fn is None or image_transforms_fn is None:
60+
raise ValueError(f"dataset type {config.dataset_type} needs to pass a tokenize_fn and image_transforms_fn")
61+
62+
if config.dataset_type == "tfrecord" and config.cache_latents_text_encoder_outputs and (feature_description is None or prepare_sample_fn is None):
63+
raise ValueError(f"dataset type {config.dataset_type} needs to pass a feature_description dictionary and prepare_sample_fn function when cache_latents_text_encoder_outputs is True.")
64+
5565
if config.dataset_type == "hf":
5666
return _hf_data_processing.make_hf_streaming_iterator(
5767
config,
@@ -87,6 +97,8 @@ def make_data_iterator(
8797
dataloading_host_count,
8898
mesh,
8999
global_batch_size,
100+
feature_description,
101+
prepare_sample_fn
90102
)
91103
else:
92104
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"

src/maxdiffusion/trainers/sdxl_trainer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import threading
2121
import time
2222
import numpy as np
23+
import tensorflow as tf
2324
import jax
2425
import jax.numpy as jnp
2526
from jax.sharding import PartitionSpec as P
@@ -140,6 +141,21 @@ def load_dataset(self, pipeline, params, train_states):
140141
p_vae_apply=p_vae_apply,
141142
)
142143

144+
feature_description = {
145+
"pixel_values": tf.io.FixedLenFeature([], tf.string),
146+
"input_ids": tf.io.FixedLenFeature([], tf.string),
147+
"prompt_embeds": tf.io.FixedLenFeature([], tf.string),
148+
"text_embeds": tf.io.FixedLenFeature([], tf.string),
149+
}
150+
151+
def prepare_sample(features):
152+
pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32)
153+
input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32)
154+
prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32)
155+
text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32)
156+
157+
return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds}
158+
143159
data_iterator = make_data_iterator(
144160
config,
145161
jax.process_index(),
@@ -148,6 +164,8 @@ def load_dataset(self, pipeline, params, train_states):
148164
total_train_batch_size,
149165
tokenize_fn=tokenize_fn,
150166
image_transforms_fn=image_transforms_fn,
167+
feature_description=feature_description,
168+
prepare_sample_fn=prepare_sample
151169
)
152170

153171
return data_iterator

0 commit comments

Comments
 (0)