Skip to content

Commit b4b5a45

Browse files
entrpn and Juan Acevedo authored
Wan training cont (#191)
* flow match scheduler + data to tf records * training pipeline with image dataset. --------- Co-authored-by: Juan Acevedo <juancevedo@google.com>
1 parent e2cb67f commit b4b5a45

12 files changed

Lines changed: 678 additions & 93 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,14 +141,15 @@ ici_tensor_parallelism: 1
141141
# Replace with dataset path or train_data_dir. One has to be set.
142142
dataset_name: 'diffusers/pokemon-gpt4-captions'
143143
train_split: 'train'
144-
dataset_type: 'tf'
144+
dataset_type: 'tfrecord'
145145
cache_latents_text_encoder_outputs: True
146146
# cache_latents_text_encoder_outputs only applies to dataset_type="tf",
147147
# only apply to small dataset that fits in memory
148148
# prepare image latents and text encoder outputs
149149
# Reduce memory consumption and reduce step time during training
150150
# transformed dataset is saved at dataset_save_location
151-
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
151+
dataset_save_location: ''
152+
load_tfrecord_cached: True
152153
train_data_dir: ''
153154
dataset_config_name: ''
154155
jax_cache_dir: ''
@@ -185,6 +186,10 @@ per_device_batch_size: 1
185186
# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
186187
global_batch_size: 0
187188

189+
# For creating tfrecords from dataset
190+
tfrecords_dir: ''
191+
no_records_per_shard: 0
192+
188193
warmup_steps_fraction: 0.1
189194
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
190195

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
"""
18+
Prepare tfrecords with latents and text embeddings preprocessed.
19+
1. Download the dataset
20+
"""
21+
22+
import os
23+
import functools
24+
from absl import app
25+
from typing import Sequence, Union, List
26+
from datasets import load_dataset
27+
import numpy as np
28+
import jax
29+
import jax.numpy as jnp
30+
from jax.sharding import Mesh
31+
from maxdiffusion import pyconfig, max_utils
32+
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
33+
from maxdiffusion.video_processor import VideoProcessor
34+
35+
import tensorflow as tf
36+
37+
38+
def image_feature(value):
  """JPEG-encode an image tensor and wrap the bytes in a tf.train bytes Feature."""
  jpeg_bytes = tf.io.encode_jpeg(value).numpy()
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[jpeg_bytes]))
41+
42+
43+
def bytes_feature(value):
  """Wrap an eager tf string tensor's raw bytes in a tf.train bytes Feature."""
  raw_bytes = value.numpy()
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw_bytes]))
46+
47+
48+
def float_feature(value):
  """Wrap a single float / double in a tf.train float Feature."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)
51+
52+
53+
def int64_feature(value):
  """Wrap a single bool / enum / int / uint in a tf.train int64 Feature."""
  int_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int_list)
56+
57+
58+
def float_feature_list(value):
  """Wrap a sequence of floats / doubles in a single tf.train float Feature."""
  float_list = tf.train.FloatList(value=value)
  return tf.train.Feature(float_list=float_list)
61+
62+
63+
def create_example(latent, hidden_states):
  """Serialize one (latent, text-embedding) pair into a tf.train.Example byte string.

  Both tensors are stored via tf.io.serialize_tensor so readers can recover
  them with tf.io.parse_tensor.
  """
  serialized_latent = tf.io.serialize_tensor(latent)
  serialized_hidden = tf.io.serialize_tensor(hidden_states)
  features = tf.train.Features(
      feature={
          "latents": bytes_feature(serialized_latent),
          "encoder_hidden_states": bytes_feature(serialized_hidden),
      }
  )
  return tf.train.Example(features=features).SerializeToString()
72+
73+
74+
def text_encode(pipeline, prompt: Union[str, List[str]]):
  """Run the pipeline's T5 text encoder on `prompt` and return the result as a numpy array."""
  # NOTE(review): relies on the pipeline's private _get_t5_prompt_embeds helper.
  embeds = pipeline._get_t5_prompt_embeds(prompt)
  return embeds.detach().numpy()
78+
79+
80+
def vae_encode(video, rng, vae, vae_cache):
  """Encode a video batch with the VAE and sample a latent from the posterior using `rng`."""
  encoded = vae.encode(video, feat_cache=vae_cache)
  return encoded.latent_dist.sample(rng)
84+
85+
86+
def generate_dataset(config, pipeline):
  """Preprocess an image dataset into sharded TFRecords of VAE latents and T5 embeddings.

  Each batch's images are treated as one-frame videos, VAE-encoded into
  latents, and their captions T5-encoded; every (latent, embedding) pair is
  written as one tf.train.Example. A new shard file is started every
  `config.no_records_per_shard` records.

  Args:
    config: HyperParameters providing tfrecords_dir, no_records_per_shard,
      dataset_name, caption_column, image_column, height, width, seed,
      mesh_axes and weights_dtype.
    pipeline: WanPipeline providing the VAE (`vae`, `vae_cache`) and the T5
      text encoder used by text_encode.
  """
  tfrecords_dir = config.tfrecords_dir
  if not os.path.exists(tfrecords_dir):
    os.makedirs(tfrecords_dir)

  def _new_writer(shard_num, record_count):
    # Shard filename encodes the shard index and the running record count.
    return tf.io.TFRecordWriter(tfrecords_dir + "/file_%.2i-%i.tfrec" % (shard_num, record_count))

  tf_rec_num = 0
  no_records_per_shard = config.no_records_per_shard
  global_record_count = 0
  writer = _new_writer(tf_rec_num, global_record_count + no_records_per_shard)
  shard_record_count = 0

  # create mesh
  devices_array = max_utils.create_device_mesh(config)
  mesh = Mesh(devices_array, config.mesh_axes)
  rng = jax.random.key(config.seed)

  vae_scale_factor_spatial = 2 ** len(pipeline.vae.temperal_downsample)
  video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial)

  # jit vae fn; vae and vae_cache are bound once via functools.partial.
  p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache))

  # Load dataset
  ds = load_dataset(config.dataset_name, split="train")
  ds = ds.shuffle(seed=config.seed)
  ds = ds.select_columns([config.caption_column, config.image_column])
  batch_size = 10
  for i in range(0, len(ds), batch_size):
    rng, new_rng = jax.random.split(rng)
    # Fix: index by the configured column names instead of hard-coded
    # "text"/"image", which broke any dataset with different column names.
    batch = ds[i : i + batch_size]
    text = batch[config.caption_column]
    videos = batch[config.image_column]

    videos = [video_processor.preprocess_video([video], height=config.height, width=config.width) for video in videos]
    video = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
    with mesh:
      latents = p_vae_encode(video=video, rng=new_rng)
    # Move the channel (last) axis to axis 1 before serialization.
    latents = jnp.transpose(latents, (0, 4, 1, 2, 3))
    encoder_hidden_states = text_encode(pipeline, text)
    for latent, encoder_hidden_state in zip(latents, encoder_hidden_states):
      writer.write(create_example(latent, encoder_hidden_state))
      shard_record_count += 1
      global_record_count += 1

      if shard_record_count >= no_records_per_shard:
        writer.close()
        tf_rec_num += 1
        writer = _new_writer(tf_rec_num, global_record_count + no_records_per_shard)
        shard_record_count = 0

  # Fix: close (and flush) the final, possibly partial shard; previously the
  # last writer was leaked and trailing records could be lost.
  writer.close()
139+
140+
141+
def run(config):
  """Build a transformer-free Wan pipeline and write the preprocessed dataset."""
  # The transformer is not needed for latent/text preprocessing, so skip loading it.
  wan_pipeline = WanPipeline.from_pretrained(config, load_transformer=False)
  generate_dataset(config, wan_pipeline)
145+
146+
147+
def main(argv: Sequence[str]) -> None:
  """absl entry point: parse maxdiffusion config flags, then run preprocessing."""
  pyconfig.initialize(argv)
  config = pyconfig.config
  run(config)
150+
151+
152+
# Script entry point: let absl parse flags and dispatch to main().
if __name__ == "__main__":
  app.run(main)

src/maxdiffusion/generate_wan.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
from maxdiffusion.utils import export_to_video
2222

2323

24-
def run(config):
24+
def run(config, pipeline=None, filename_prefix=""):
2525
print("seed: ", config.seed)
26-
pipeline = WanPipeline.from_pretrained(config)
26+
if pipeline is None:
27+
pipeline = WanPipeline.from_pretrained(config)
2728
s0 = time.perf_counter()
2829

2930
# Skip layer guidance
@@ -59,7 +60,7 @@ def run(config):
5960

6061
print("compile time: ", (time.perf_counter() - s0))
6162
for i in range(len(videos)):
62-
export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
63+
export_to_video(videos[i], f"{filename_prefix}wan_output_{config.seed}_{i}.mp4", fps=config.fps)
6364
s0 = time.perf_counter()
6465
videos = pipeline(
6566
prompt=prompt,

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -73,43 +73,26 @@ def make_tf_iterator(
7373
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
7474
return train_iter
7575

76+
7677
def make_cached_tfrecord_iterator(
77-
config,
78-
dataloading_host_index,
79-
dataloading_host_count,
80-
mesh,
81-
global_batch_size,
78+
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
8279
):
8380
"""
8481
New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
8582
latents, input_ids, prompt_embeds, and text_embeds.
8683
"""
87-
feature_description = {
88-
"pixel_values": tf.io.FixedLenFeature([], tf.string),
89-
"input_ids": tf.io.FixedLenFeature([], tf.string),
90-
"prompt_embeds": tf.io.FixedLenFeature([], tf.string),
91-
"text_embeds": tf.io.FixedLenFeature([], tf.string),
92-
}
9384

9485
def _parse_tfrecord_fn(example):
9586
return tf.io.parse_single_example(example, feature_description)
9687

97-
def prepare_sample(features):
98-
pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32)
99-
input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32)
100-
prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32)
101-
text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32)
102-
103-
return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds}
104-
10588
# This pipeline reads the sharded files and applies the parsing and preparation.
10689
filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
10790

10891
train_ds = (
10992
tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
11093
.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
11194
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
112-
.map(prepare_sample, num_parallel_calls=AUTOTUNE)
95+
.map(prepare_sample_fn, num_parallel_calls=AUTOTUNE)
11396
.shuffle(global_batch_size * 10)
11497
.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
11598
.repeat(-1)
@@ -123,11 +106,7 @@ def prepare_sample(features):
123106

124107
# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
125108
def make_tfrecord_iterator(
126-
config,
127-
dataloading_host_index,
128-
dataloading_host_count,
129-
mesh,
130-
global_batch_size,
109+
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
131110
):
132111
"""Iterator for TFRecord format. For Laion dataset,
133112
check out preparation script
@@ -136,12 +115,22 @@ def make_tfrecord_iterator(
136115

137116
# set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
138117
# pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
139-
# Datset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
140-
if (config.cache_latents_text_encoder_outputs
118+
# Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
119+
if (
120+
config.cache_latents_text_encoder_outputs
141121
and os.path.isdir(config.dataset_save_location)
142-
and 'load_tfrecord_cached'in config.get_keys()
143-
and config.load_tfrecord_cached):
144-
return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)
122+
and "load_tfrecord_cached" in config.get_keys()
123+
and config.load_tfrecord_cached
124+
):
125+
return make_cached_tfrecord_iterator(
126+
config,
127+
dataloading_host_index,
128+
dataloading_host_count,
129+
mesh,
130+
global_batch_size,
131+
feature_description,
132+
prepare_sample_fn,
133+
)
145134

146135
feature_description = {
147136
"moments": tf.io.FixedLenFeature([], tf.string),

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,25 @@ def make_data_iterator(
5050
global_batch_size,
5151
tokenize_fn=None,
5252
image_transforms_fn=None,
53+
feature_description=None,
54+
prepare_sample_fn=None,
5355
):
5456
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""
57+
58+
if config.dataset_type == "hf" or config.dataset_type == "tf":
59+
if tokenize_fn is None or image_transforms_fn is None:
60+
raise ValueError(f"dataset type {config.dataset_type} needs to pass a tokenize_fn and image_transforms_fn")
61+
62+
if (
63+
config.dataset_type == "tfrecord"
64+
and config.cache_latents_text_encoder_outputs
65+
and feature_description is None
66+
and prepare_sample_fn is None
67+
):
68+
raise ValueError(
69+
f"dataset type {config.dataset_type} needs to pass a feature_description dictionary and prepare_sample_fn function when cache_latents_text_encoder_outputs is True."
70+
)
71+
5572
if config.dataset_type == "hf":
5673
return _hf_data_processing.make_hf_streaming_iterator(
5774
config,
@@ -87,6 +104,8 @@ def make_data_iterator(
87104
dataloading_host_count,
88105
mesh,
89106
global_batch_size,
107+
feature_description,
108+
prepare_sample_fn,
90109
)
91110
else:
92111
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
def basic_clean(text):
4242
if is_ftfy_available():
4343
import ftfy
44+
4445
text = ftfy.fix_text(text)
4546
text = html.unescape(html.unescape(text))
4647
return text.strip()
@@ -221,7 +222,7 @@ def load_scheduler(cls, config):
221222
return scheduler, scheduler_state
222223

223224
@classmethod
224-
def from_pretrained(cls, config: HyperParameters, vae_only=False):
225+
def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
225226
devices_array = max_utils.create_device_mesh(config)
226227
mesh = Mesh(devices_array, config.mesh_axes)
227228
rng = jax.random.key(config.seed)
@@ -232,8 +233,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False):
232233
scheduler_state = None
233234
text_encoder = None
234235
if not vae_only:
235-
with mesh:
236-
transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
236+
if load_transformer:
237+
with mesh:
238+
transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
237239

238240
text_encoder = cls.load_text_encoder(config=config)
239241
tokenizer = cls.load_tokenizer(config=config)
@@ -397,7 +399,7 @@ def __call__(
397399
num_channels_latents=num_channel_latents,
398400
)
399401

400-
data_sharding = NamedSharding(self.devices_array, P())
402+
data_sharding = NamedSharding(self.mesh, P())
401403
if len(prompt) % jax.device_count() == 0:
402404
data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))
403405

src/maxdiffusion/schedulers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
_import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
4444
_import_structure["scheduling_ddpm_flax"] = ["FlaxDDPMScheduler"]
4545
_import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"]
46-
_import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
46+
_import_structure["scheduling_flow_match_flax"] = ["FlaxFlowMatchScheduler"]
4747
_import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"]
4848
_import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"]
4949
_import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"]
@@ -70,6 +70,7 @@
7070
from .scheduling_ddpm_flax import FlaxDDPMScheduler
7171
from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
7272
from .scheduling_euler_discrete_flax import FlaxEulerDiscreteScheduler
73+
from .scheduling_flow_match_flax import FlowMatchScheduler
7374
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
7475
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
7576
from .scheduling_pndm_flax import FlaxPNDMScheduler

0 commit comments

Comments
 (0)