Skip to content

Commit 1536f42

Browse files
author
Juan Acevedo
committed
flow match scheduler + data to tf records
1 parent e2cb67f commit 1536f42

7 files changed

Lines changed: 496 additions & 23 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ per_device_batch_size: 1
185185
# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
186186
global_batch_size: 0
187187

188+
# For creating tfrecords from dataset
189+
tfrecords_dir: ''
190+
no_records_per_shard: 0
191+
188192
warmup_steps_fraction: 0.1
189193
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
190194

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
"""
18+
Prepare tfrecords with latents and text embeddings preprocessed.
19+
1. Download the dataset
20+
"""
21+
22+
import os
23+
import functools
24+
from absl import app
25+
from typing import Sequence, Union, List
26+
from datasets import load_dataset
27+
import numpy as np
28+
import jax
29+
import jax.numpy as jnp
30+
from jax.sharding import Mesh
31+
from maxdiffusion import pyconfig, max_utils
32+
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
33+
from maxdiffusion.video_processor import VideoProcessor
34+
35+
import tensorflow as tf
36+
37+
def image_feature(value):
  """Returns a bytes_list Feature holding *value* encoded as JPEG.

  Unlike `bytes_feature`, this first JPEG-encodes the image tensor
  (the original docstring was copy-pasted from `bytes_feature`).
  """
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()]))
40+
41+
42+
def bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  raw = value.numpy()
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))
45+
46+
47+
def float_feature(value):
  """Returns a float_list from a float / double."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)
50+
51+
52+
def int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  int_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int_list)
55+
56+
57+
def float_feature_list(value):
  """Returns a float_list Feature from a list of floats / doubles.

  (The original docstring claimed it returns "a list of float_list";
  it wraps the whole list into a single FloatList.)
  """
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))
60+
61+
def create_example(latent, hidden_states):
  """Serialize one (latent, text-embedding) pair into a tf.train.Example byte string."""
  serialized_latent = tf.io.serialize_tensor(latent)
  serialized_states = tf.io.serialize_tensor(hidden_states)
  features = tf.train.Features(
      feature={
          "latents": bytes_feature(serialized_latent),
          "encoder_hidden_states": bytes_feature(serialized_states),
      }
  )
  return tf.train.Example(features=features).SerializeToString()
70+
71+
72+
def text_encode(pipeline, prompt: Union[str, List[str]]):
  """Embed *prompt* with the pipeline's T5 text encoder.

  Returns a numpy array (the encoder output is detached first — presumably
  a torch tensor; verify against the pipeline implementation).
  """
  embeds = pipeline._get_t5_prompt_embeds(prompt)
  return embeds.detach().numpy()
76+
77+
def vae_encode(video, rng, vae, vae_cache):
  """Encode a video batch into VAE latents, sampling the latent distribution with *rng*."""
  encoded = vae.encode(video, feat_cache=vae_cache)
  return encoded.latent_dist.sample(rng)
81+
82+
def generate_dataset(config, pipeline):
  """Preprocess the dataset into sharded TFRecords of VAE latents + text embeddings.

  Args:
    config: HyperParameters exposing tfrecords_dir, no_records_per_shard,
      dataset_name, caption_column, image_column, height, width, seed,
      weights_dtype and mesh settings.
    pipeline: WanPipeline providing the VAE, its feature cache and the
      text encoder (the transformer is not needed).
  """
  tfrecords_dir = config.tfrecords_dir
  os.makedirs(tfrecords_dir, exist_ok=True)

  no_records_per_shard = config.no_records_per_shard

  def _shard_path(shard_idx, record_upper_bound):
    # Keeps the original naming scheme: file_<shard>-<cumulative bound>.tfrec
    return tfrecords_dir + "/file_%.2i-%i.tfrec" % (shard_idx, record_upper_bound)

  tf_rec_num = 0
  global_record_count = 0
  writer = tf.io.TFRecordWriter(_shard_path(tf_rec_num, global_record_count + no_records_per_shard))
  shard_record_count = 0

  # create mesh
  devices_array = max_utils.create_device_mesh(config)
  mesh = Mesh(devices_array, config.mesh_axes)
  rng = jax.random.key(config.seed)

  # NOTE: 'temperal_downsample' is the attribute's actual (misspelled) name.
  vae_scale_factor_spatial = 2 ** len(pipeline.vae.temperal_downsample)
  video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial)

  # jit vae fn; vae/vae_cache are bound once via partial.
  p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache))

  # Load dataset
  ds = load_dataset(config.dataset_name, split='train')
  ds = ds.shuffle(seed=config.seed)
  ds = ds.select_columns([config.caption_column, config.image_column])
  batch_size = 10
  for start in range(0, len(ds), batch_size):
    rng, new_rng = jax.random.split(rng)
    batch = ds[start:start + batch_size]
    # Use the configured column names; the original hard-coded 'text'/'image',
    # which breaks for datasets whose columns are named differently (and
    # contradicts the select_columns call above).
    text = batch[config.caption_column]
    video = batch[config.image_column]

    video = [np.expand_dims(np.array(frame), axis=0) for frame in video]
    video = video_processor.preprocess_video(video, height=config.height, width=config.width)
    video = jnp.array(np.array(video), dtype=config.weights_dtype)
    with mesh:
      latents = p_vae_encode(video=video, rng=new_rng)
    encoder_hidden_states = text_encode(pipeline, text)
    example = create_example(latents, encoder_hidden_states)
    writer.write(example)
    shard_record_count += batch_size
    global_record_count += batch_size
    if shard_record_count >= no_records_per_shard:
      # Roll over to a new shard. Increment the shard index *before* opening
      # the next writer: the original incremented afterwards, so consecutive
      # shard files reused the same index.
      writer.close()
      tf_rec_num += 1
      writer = tf.io.TFRecordWriter(_shard_path(tf_rec_num, global_record_count + no_records_per_shard))
      shard_record_count = 0
  # Close the final (possibly partial) shard. The original leaked this
  # writer, which can silently drop buffered records.
  writer.close()
134+
135+
136+
137+
def run(config):
  """Build the pipeline and preprocess the dataset into tfrecords."""
  # The diffusion transformer is unused during preprocessing, so skip loading it.
  pipeline = WanPipeline.from_pretrained(config, load_transformer=False)
  generate_dataset(config, pipeline)
141+
142+
143+
144+
def main(argv: Sequence[str]) -> None:
  """absl entry point: initialize config from argv, then run preprocessing."""
  pyconfig.initialize(argv)
  config = pyconfig.config
  run(config)


if __name__ == "__main__":
  app.run(main)

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def load_scheduler(cls, config):
221221
return scheduler, scheduler_state
222222

223223
@classmethod
224-
def from_pretrained(cls, config: HyperParameters, vae_only=False):
224+
def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
225225
devices_array = max_utils.create_device_mesh(config)
226226
mesh = Mesh(devices_array, config.mesh_axes)
227227
rng = jax.random.key(config.seed)
@@ -232,8 +232,9 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False):
232232
scheduler_state = None
233233
text_encoder = None
234234
if not vae_only:
235-
with mesh:
236-
transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
235+
if load_transformer:
236+
with mesh:
237+
transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config)
237238

238239
text_encoder = cls.load_text_encoder(config=config)
239240
tokenizer = cls.load_tokenizer(config=config)

src/maxdiffusion/schedulers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
_import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
4444
_import_structure["scheduling_ddpm_flax"] = ["FlaxDDPMScheduler"]
4545
_import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"]
46-
_import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
46+
_import_structure["scheduling_flow_match_flax"] = ["FlaxFlowMatchScheduler"]
4747
_import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"]
4848
_import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"]
4949
_import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"]
@@ -70,6 +70,7 @@
7070
from .scheduling_ddpm_flax import FlaxDDPMScheduler
7171
from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
7272
from .scheduling_euler_discrete_flax import FlaxEulerDiscreteScheduler
# Must export the same name as _import_structure["scheduling_flow_match_flax"]
# ("FlaxFlowMatchScheduler"); the eager path imported "FlowMatchScheduler",
# which diverges from the lazy-import table. TODO(review): confirm the class
# name in scheduling_flow_match_flax.py.
from .scheduling_flow_match_flax import FlaxFlowMatchScheduler
7374
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
7475
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
7576
from .scheduling_pndm_flax import FlaxPNDMScheduler

0 commit comments

Comments
 (0)