Skip to content

Commit ae9c952

Browse files
Merge branch 'main' into wan_context_parallelism_inference
2 parents 3f6eb05 + 462f463 commit ae9c952

15 files changed

Lines changed: 842 additions & 103 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ log_period: 100
2929

3030
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
3131

32+
# Overrides the transformer from pretrained_model_name_or_path
33+
wan_transformer_pretrained_model_name_or_path: ''
34+
3235
unet_checkpoint: ''
3336
revision: ''
3437
# This will convert the weights to this dtype.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
"""
18+
Prepare tfrecords with latents and text embeddings preprocessed.
19+
1. Download the dataset
20+
"""
21+
22+
import os
23+
import functools
24+
from absl import app
25+
from typing import Sequence, Union, List
26+
from datasets import load_dataset
27+
import numpy as np
28+
import jax
29+
import jax.numpy as jnp
30+
from jax.sharding import Mesh
31+
from maxdiffusion import pyconfig, max_utils
32+
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
33+
from maxdiffusion.video_processor import VideoProcessor
34+
35+
import tensorflow as tf
36+
37+
38+
def image_feature(value):
39+
"""Returns a bytes_list from a string / byte."""
40+
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()]))
41+
42+
43+
def bytes_feature(value):
44+
"""Returns a bytes_list from a string / byte."""
45+
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()]))
46+
47+
48+
def float_feature(value):
49+
"""Returns a float_list from a float / double."""
50+
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
51+
52+
53+
def int64_feature(value):
54+
"""Returns an int64_list from a bool / enum / int / uint."""
55+
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
56+
57+
58+
def float_feature_list(value):
59+
"""Returns a list of float_list from a float / double."""
60+
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
61+
62+
63+
def create_example(latent, hidden_states):
64+
latent = tf.io.serialize_tensor(latent)
65+
hidden_states = tf.io.serialize_tensor(hidden_states)
66+
feature = {
67+
"latents": bytes_feature(latent),
68+
"encoder_hidden_states": bytes_feature(hidden_states),
69+
}
70+
example = tf.train.Example(features=tf.train.Features(feature=feature))
71+
return example.SerializeToString()
72+
73+
74+
def text_encode(pipeline, prompt: Union[str, List[str]]):
75+
encoder_hidden_states = pipeline._get_t5_prompt_embeds(prompt)
76+
encoder_hidden_states = encoder_hidden_states.detach().numpy()
77+
return encoder_hidden_states
78+
79+
80+
def vae_encode(video, rng, vae, vae_cache):
81+
latent = vae.encode(video, feat_cache=vae_cache)
82+
latent = latent.latent_dist.sample(rng)
83+
return latent
84+
85+
86+
def generate_dataset(config, pipeline):
87+
88+
tfrecords_dir = config.tfrecords_dir
89+
if not os.path.exists(tfrecords_dir):
90+
os.makedirs(tfrecords_dir)
91+
92+
tf_rec_num = 0
93+
no_records_per_shard = config.no_records_per_shard
94+
global_record_count = 0
95+
writer = tf.io.TFRecordWriter(
96+
tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard))
97+
)
98+
shard_record_count = 0
99+
100+
# create mesh
101+
devices_array = max_utils.create_device_mesh(config)
102+
mesh = Mesh(devices_array, config.mesh_axes)
103+
rng = jax.random.key(config.seed)
104+
105+
vae_scale_factor_spatial = 2 ** len(pipeline.vae.temperal_downsample)
106+
video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial)
107+
108+
# jit vae fun.
109+
p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache))
110+
111+
# Load dataset
112+
ds = load_dataset(config.dataset_name, split="train")
113+
ds = ds.shuffle(seed=config.seed)
114+
ds = ds.select_columns([config.caption_column, config.image_column])
115+
batch_size = 10
116+
for i in range(0, len(ds), batch_size):
117+
rng, new_rng = jax.random.split(rng)
118+
text = ds[i : i + batch_size]["text"]
119+
videos = ds[i : i + batch_size]["image"]
120+
121+
videos = [video_processor.preprocess_video([video], height=config.height, width=config.width) for video in videos]
122+
video = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
123+
with mesh:
124+
latents = p_vae_encode(video=video, rng=new_rng)
125+
latents = jnp.transpose(latents, (0, 4, 1, 2, 3))
126+
encoder_hidden_states = text_encode(pipeline, text)
127+
for latent, encoder_hidden_state in zip(latents, encoder_hidden_states):
128+
writer.write(create_example(latent, encoder_hidden_state))
129+
shard_record_count += 1
130+
global_record_count += 1
131+
132+
if shard_record_count >= no_records_per_shard:
133+
writer.close()
134+
tf_rec_num += 1
135+
writer = tf.io.TFRecordWriter(
136+
tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard))
137+
)
138+
shard_record_count = 0
139+
140+
141+
def run(config):
142+
pipeline = WanPipeline.from_pretrained(config, load_transformer=False)
143+
# Don't need the transformer for preprocessing.
144+
generate_dataset(config, pipeline)
145+
146+
147+
def main(argv: Sequence[str]) -> None:
148+
pyconfig.initialize(argv)
149+
run(pyconfig.config)
150+
151+
152+
if __name__ == "__main__":
153+
app.run(main)

src/maxdiffusion/generate_wan.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,12 @@
1616
import jax
1717
import time
1818
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
19-
from maxdiffusion import pyconfig, max_logging
19+
from maxdiffusion import pyconfig, max_logging, max_utils
2020
from absl import app
2121
from maxdiffusion.utils import export_to_video
2222

2323
jax.config.update('jax_use_shardy_partitioner', True)
2424

25-
2625
def run(config, pipeline=None, filename_prefix=""):
2726
print("seed: ", config.seed)
2827
if pipeline is None:
@@ -61,8 +60,12 @@ def run(config, pipeline=None, filename_prefix=""):
6160
)
6261

6362
print("compile time: ", (time.perf_counter() - s0))
63+
saved_video_path = []
6464
for i in range(len(videos)):
65-
export_to_video(videos[i], f"{filename_prefix}wan_output_{config.seed}_{i}.mp4", fps=config.fps)
65+
video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
66+
export_to_video(videos[i], video_path, fps=config.fps)
67+
saved_video_path.append(video_path)
68+
6669
s0 = time.perf_counter()
6770
videos = pipeline(
6871
prompt=prompt,
@@ -76,12 +79,11 @@ def run(config, pipeline=None, filename_prefix=""):
7679
slg_start=slg_start,
7780
slg_end=slg_end,
7881
)
79-
print("generation time: ", (time.perf_counter() - s0))
80-
for i in range(len(videos)):
81-
export_to_video(videos[i], f"wan_output_{config.seed}_{i}.mp4", fps=config.fps)
82+
print("compile time: ", (time.perf_counter() - s0))
8283

8384
s0 = time.perf_counter()
84-
with jax.profiler.trace("/tmp/trace/"):
85+
if config.enable_profiler:
86+
max_utils.activate_profiler(config)
8587
videos = pipeline(
8688
prompt=prompt,
8789
negative_prompt=negative_prompt,
@@ -94,7 +96,9 @@ def run(config, pipeline=None, filename_prefix=""):
9496
slg_start=slg_start,
9597
slg_end=slg_end,
9698
)
97-
print("generation time: ", (time.perf_counter() - s0))
99+
max_utils.deactivate_profiler(config)
100+
print("generation time: ", (time.perf_counter() - s0))
101+
return saved_video_path
98102

99103

100104
def main(argv: Sequence[str]) -> None:

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -73,43 +73,26 @@ def make_tf_iterator(
7373
train_iter = multihost_dataloading.MultiHostDataLoadIterator(train_ds, mesh)
7474
return train_iter
7575

76+
7677
def make_cached_tfrecord_iterator(
77-
config,
78-
dataloading_host_index,
79-
dataloading_host_count,
80-
mesh,
81-
global_batch_size,
78+
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
8279
):
8380
"""
8481
New iterator for TFRecords that contain the full 4 pre-computed latents and embeddings:
8582
latents, input_ids, prompt_embeds, and text_embeds.
8683
"""
87-
feature_description = {
88-
"pixel_values": tf.io.FixedLenFeature([], tf.string),
89-
"input_ids": tf.io.FixedLenFeature([], tf.string),
90-
"prompt_embeds": tf.io.FixedLenFeature([], tf.string),
91-
"text_embeds": tf.io.FixedLenFeature([], tf.string),
92-
}
9384

9485
def _parse_tfrecord_fn(example):
9586
return tf.io.parse_single_example(example, feature_description)
9687

97-
def prepare_sample(features):
98-
pixel_values = tf.io.parse_tensor(features["pixel_values"], out_type=tf.float32)
99-
input_ids = tf.io.parse_tensor(features["input_ids"], out_type=tf.int32)
100-
prompt_embeds = tf.io.parse_tensor(features["prompt_embeds"], out_type=tf.float32)
101-
text_embeds = tf.io.parse_tensor(features["text_embeds"], out_type=tf.float32)
102-
103-
return {"pixel_values": pixel_values, "input_ids": input_ids, "prompt_embeds": prompt_embeds, "text_embeds": text_embeds}
104-
10588
# This pipeline reads the sharded files and applies the parsing and preparation.
10689
filenames = tf.io.gfile.glob(os.path.join(config.train_data_dir, "*"))
10790

10891
train_ds = (
10992
tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
11093
.shard(num_shards=dataloading_host_count, index=dataloading_host_index)
11194
.map(_parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
112-
.map(prepare_sample, num_parallel_calls=AUTOTUNE)
95+
.map(prepare_sample_fn, num_parallel_calls=AUTOTUNE)
11396
.shuffle(global_batch_size * 10)
11497
.batch(global_batch_size // dataloading_host_count, drop_remainder=True)
11598
.repeat(-1)
@@ -123,11 +106,7 @@ def prepare_sample(features):
123106

124107
# TODO - https://github.com/google/array_record/blob/main/beam/examples/example_gcs_conversion.py
125108
def make_tfrecord_iterator(
126-
config,
127-
dataloading_host_index,
128-
dataloading_host_count,
129-
mesh,
130-
global_batch_size,
109+
config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size, feature_description, prepare_sample_fn
131110
):
132111
"""Iterator for TFRecord format. For Laion dataset,
133112
check out preparation script
@@ -136,12 +115,22 @@ def make_tfrecord_iterator(
136115

137116
# set load_tfrecord_cached to True in config to use pre-processed tfrecord dataset.
138117
# pedagogical_examples/dataset_tf_cache_to_tfrecord.py to convert tf preprocessed dataset to tfrecord.
139-
# Datset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
140-
if (config.cache_latents_text_encoder_outputs
118+
# Dataset cache in github runner test doesn't contain all the features since its shared, Use the default tfrecord iterator.
119+
if (
120+
config.cache_latents_text_encoder_outputs
141121
and os.path.isdir(config.dataset_save_location)
142-
and 'load_tfrecord_cached'in config.get_keys()
143-
and config.load_tfrecord_cached):
144-
return make_cached_tfrecord_iterator(config, dataloading_host_index, dataloading_host_count, mesh, global_batch_size)
122+
and "load_tfrecord_cached" in config.get_keys()
123+
and config.load_tfrecord_cached
124+
):
125+
return make_cached_tfrecord_iterator(
126+
config,
127+
dataloading_host_index,
128+
dataloading_host_count,
129+
mesh,
130+
global_batch_size,
131+
feature_description,
132+
prepare_sample_fn,
133+
)
145134

146135
feature_description = {
147136
"moments": tf.io.FixedLenFeature([], tf.string),

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,25 @@ def make_data_iterator(
5050
global_batch_size,
5151
tokenize_fn=None,
5252
image_transforms_fn=None,
53+
feature_description=None,
54+
prepare_sample_fn=None,
5355
):
5456
"""Make data iterator for SD1, 2, XL, dataset_types in (hf, tf, tfrecord)"""
57+
58+
if config.dataset_type == "hf" or config.dataset_type == "tf":
59+
if tokenize_fn is None or image_transforms_fn is None:
60+
raise ValueError(f"dataset type {config.dataset_type} needs to pass a tokenize_fn and image_transforms_fn")
61+
62+
if (
63+
config.dataset_type == "tfrecord"
64+
and config.cache_latents_text_encoder_outputs
65+
and feature_description is None
66+
and prepare_sample_fn is None
67+
):
68+
raise ValueError(
69+
f"dataset type {config.dataset_type} needs to pass a feature_description dictionary and prepare_sample_fn function when cache_latents_text_encoder_outputs is True."
70+
)
71+
5572
if config.dataset_type == "hf":
5673
return _hf_data_processing.make_hf_streaming_iterator(
5774
config,
@@ -87,6 +104,8 @@ def make_data_iterator(
87104
dataloading_host_count,
88105
mesh,
89106
global_batch_size,
107+
feature_description,
108+
prepare_sample_fn,
90109
)
91110
else:
92111
assert False, f"Unknown dataset_type {config.dataset_type}, dataset_type must be in (tf, tfrecord, hf, grain)"

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def _tpu_flash_attention(
187187
axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel)
188188
named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel)
189189

190-
cp_size=1
190+
shard_head_size=mesh.shape['tensor']
191191

192192
@functools.partial(
193193
jax.jit,
@@ -200,12 +200,11 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
200200
splash_kernel = splash_attention_kernel.make_splash_mha(
201201
mask=multi_head_mask,
202202
head_shards=shard_head_size, # the sizes of the axis is sharding over heads
203-
q_seq_shards=cp_size,
203+
q_seq_shards=num_fsdp_shards,
204204
block_sizes=block_sizes,
205205
)
206206
return splash_kernel
207207

208-
shard_head_size = mesh.shape["tensor"]
209208
mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
210209
multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
211210
splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size))
@@ -223,10 +222,7 @@ def wrap_splash_kernel(multi_head_mask, shard_head_size=1):
223222
check_rep=False
224223
)
225224
def wrap_flash_attention(query, key, value, splash_kernel):
226-
#full_k = jax.lax.all_to_all(key, axis_name='fsdp', split_axis=2, concat_axis=2, tiled=True)
227-
#full_v = jax.lax.all_to_all(value, axis_name='fsdp', split_axis=2, concat_axis=2, tiled=True)
228225
attention_output = jax.vmap(splash_kernel)(query, key, value)
229-
#attention_output = jax.vmap(splash_kernel)(query, full_k, full_v)
230226
return attention_output
231227

232228
devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]

0 commit comments

Comments
 (0)