Commit 1e1d2e1

Video dataset (#209)
* fixes ssim.
* adds pusav1 video dataset.
* wip - adds trainer and attn changes.
* force splash attention for cross attention.
* use nnx.scan over for loop.
* support wan transformers for nnx.scan.
* fix ag from vmap/scan.
* linting.
* remove slg to simplify the code.
1 parent db4caf0 commit 1e1d2e1

11 files changed: 260 additions & 87 deletions


src/maxdiffusion/common_types.py

Lines changed: 2 additions & 0 deletions
@@ -43,3 +43,5 @@
 KEEP_1 = "activation_keep_1"
 KEEP_2 = "activation_keep_2"
 CONV_OUT = "activation_conv_out_channels"
+
+WAN_MODEL = "Wan2.1"

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 0 additions & 4 deletions
@@ -234,10 +234,6 @@ num_frames: 81
 guidance_scale: 5.0
 flow_shift: 3.0
 
-# skip layer guidance
-slg_layers: [9]
-slg_start: 0.2
-slg_end: 1.0
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
 guidance_rescale: 0.0
 num_inference_steps: 30
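
For reference, the guidance_rescale option kept above follows the rescaling in section 3.4 of the linked paper: the classifier-free-guidance output is shrunk toward the standard deviation of the text-conditioned prediction. A minimal JAX sketch of that formula (illustrative names, not maxdiffusion's actual implementation):

import jax.numpy as jnp

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
  """Rescale CFG output toward the std of the text branch (Sec. 3.4)."""
  reduce_axes = tuple(range(1, noise_pred_text.ndim))
  std_text = jnp.std(noise_pred_text, axis=reduce_axes, keepdims=True)
  std_cfg = jnp.std(noise_cfg, axis=reduce_axes, keepdims=True)
  # Shrink the over-amplified CFG prediction back to the text-branch scale.
  rescaled = noise_cfg * (std_text / std_cfg)
  # guidance_rescale=0.0 keeps plain CFG; 1.0 uses the fully rescaled output.
  return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg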
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+"""
+Copyright 2025 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+Prepare tfrecords with latents and text embeddings preprocessed.
+1. Download the dataset
+"""
+
+import os
+from absl import app
+from typing import Sequence
+import csv
+import jax.numpy as jnp
+from maxdiffusion import pyconfig
+
+import torch
+import tensorflow as tf
+
+
+def image_feature(value):
+  """Returns a bytes_list from a string / byte."""
+  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()]))
+
+
+def bytes_feature(value):
+  """Returns a bytes_list from a string / byte."""
+  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()]))
+
+
+def float_feature(value):
+  """Returns a float_list from a float / double."""
+  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
+
+
+def int64_feature(value):
+  """Returns an int64_list from a bool / enum / int / uint."""
+  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def float_feature_list(value):
+  """Returns a list of float_list from a float / double."""
+  return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+
+
+def create_example(latent, hidden_states):
+  latent = tf.io.serialize_tensor(latent)
+  hidden_states = tf.io.serialize_tensor(hidden_states)
+  feature = {
+      "latents": bytes_feature(latent),
+      "encoder_hidden_states": bytes_feature(hidden_states),
+  }
+  example = tf.train.Example(features=tf.train.Features(feature=feature))
+  return example.SerializeToString()
+
+
+def generate_dataset(config):
+
+  tfrecords_dir = config.tfrecords_dir
+  if not os.path.exists(tfrecords_dir):
+    os.makedirs(tfrecords_dir)
+
+  tf_rec_num = 0
+  no_records_per_shard = config.no_records_per_shard
+  global_record_count = 0
+  writer = tf.io.TFRecordWriter(
+      tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard))
+  )
+  shard_record_count = 0
+
+  # Load dataset
+  metadata_path = os.path.join(config.train_data_dir, "metadata.csv")
+  with open(metadata_path, "r", newline="") as file:
+    # Create a csv.reader object and skip the header row
+    csv_reader = csv.reader(file)
+    next(csv_reader)
+
+    # Iterate over each row in the CSV file
+    for row in csv_reader:
+      video_name = row[0]
+      pth_path = os.path.join(config.train_data_dir, "train", f"{video_name}.tensors.pth")
+      loaded_state_dict = torch.load(pth_path, map_location=torch.device("cpu"))
+      prompt_embeds = loaded_state_dict["prompt_emb"]["context"].squeeze()
+      latent = loaded_state_dict["latents"]
+
+      # Format we want: (Batch, Channels, Frames, Height, Width).
+      # Save them as float32 because numpy cannot read bfloat16.
+      latent = jnp.array(latent.float().numpy(), dtype=jnp.float32)
+      prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=jnp.float32)
+      writer.write(create_example(latent, prompt_embeds))
+      shard_record_count += 1
+      global_record_count += 1
+
+      if shard_record_count >= no_records_per_shard:
+        writer.close()
+        tf_rec_num += 1
+        writer = tf.io.TFRecordWriter(
+            tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard))
+        )
+        shard_record_count = 0
+
+
+def run(config):
+  generate_dataset(config)
+
+
+def main(argv: Sequence[str]) -> None:
+  pyconfig.initialize(argv)
+  run(pyconfig.config)
+
+
+if __name__ == "__main__":
+  app.run(main)
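
To sanity-check the shards, the records can be read back with tf.data. A minimal sketch, assuming only the "latents" and "encoder_hidden_states" features written by create_example above (the glob path is illustrative):

import tensorflow as tf

feature_spec = {
    "latents": tf.io.FixedLenFeature([], tf.string),
    "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
}

def parse_example(proto):
  parsed = tf.io.parse_single_example(proto, feature_spec)
  # Both tensors were written with tf.io.serialize_tensor as float32.
  latents = tf.io.parse_tensor(parsed["latents"], out_type=tf.float32)
  hidden_states = tf.io.parse_tensor(parsed["encoder_hidden_states"], out_type=tf.float32)
  return latents, hidden_states

ds = tf.data.TFRecordDataset(tf.io.gfile.glob("tfrecords/*.tfrec")).map(parse_example)
for latents, hidden_states in ds.take(1):
  print(latents.shape, hidden_states.shape)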

src/maxdiffusion/generate_wan.py

Lines changed: 0 additions & 13 deletions
@@ -29,10 +29,6 @@ def run(config, pipeline=None, filename_prefix=""):
   pipeline = WanPipeline.from_pretrained(config)
   s0 = time.perf_counter()
 
-  # Skip layer guidance
-  slg_layers = config.slg_layers
-  slg_start = config.slg_start
-  slg_end = config.slg_end
   # If global_batch_size % jax.device_count is not 0, use FSDP sharding.
   global_batch_size = config.global_batch_size
   if global_batch_size != 0:
@@ -55,9 +51,6 @@ def run(config, pipeline=None, filename_prefix=""):
       num_frames=config.num_frames,
       num_inference_steps=config.num_inference_steps,
       guidance_scale=config.guidance_scale,
-      slg_layers=slg_layers,
-      slg_start=slg_start,
-      slg_end=slg_end,
   )
 
   print("compile time: ", (time.perf_counter() - s0))
@@ -76,9 +69,6 @@ def run(config, pipeline=None, filename_prefix=""):
       num_frames=config.num_frames,
       num_inference_steps=config.num_inference_steps,
       guidance_scale=config.guidance_scale,
-      slg_layers=slg_layers,
-      slg_start=slg_start,
-      slg_end=slg_end,
   )
   print("generation time: ", (time.perf_counter() - s0))
 
@@ -93,9 +83,6 @@ def run(config, pipeline=None, filename_prefix=""):
       num_frames=config.num_frames,
       num_inference_steps=config.num_inference_steps,
       guidance_scale=config.guidance_scale,
-      slg_layers=slg_layers,
-      slg_start=slg_start,
-      slg_end=slg_end,
   )
   max_utils.deactivate_profiler(config)
   print("generation time: ", (time.perf_counter() - s0))

src/maxdiffusion/models/attention_flax.py

Lines changed: 43 additions & 11 deletions
@@ -380,7 +380,6 @@ def _apply_attention(
     )
   else:
     can_use_flash_attention = True
-
   if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention:
     return _apply_attention_dot(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
@@ -509,11 +508,12 @@ def __init__(
     heads: int,
     dim_head: int,
     use_memory_efficient_attention: bool = False,
-    split_head_dim: bool = False,
+    split_head_dim: bool = True,
     float32_qk_product: bool = True,
     axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
     axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
-    flash_min_seq_length: int = 4096,
+    # Set to 0 to force splash attention for cross attention.
+    flash_min_seq_length: int = 0,
     flash_block_sizes: BlockSizes = None,
     dtype: DType = jnp.float32,
     quant: Quant = None,
@@ -674,8 +674,10 @@ def __init__(
         dtype=dtype,
         quant=quant,
     )
-
-    kernel_axes = ("embed", "heads")
+    # The leading None axis corresponds to the weights stacked across all
+    # blocks by nnx.vmap and nnx.scan.
+    # Dims are [num_blocks, embed, heads].
+    kernel_axes = (None, "embed", "heads")
     qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes)
 
     self.query = nnx.Linear(
@@ -686,7 +688,13 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        bias_init=nnx.with_partitioning(
+            nnx.initializers.zeros,
+            (
+                None,
+                "embed",
+            ),
+        ),
     )
 
     self.key = nnx.Linear(
@@ -697,7 +705,13 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        bias_init=nnx.with_partitioning(
+            nnx.initializers.zeros,
+            (
+                None,
+                "embed",
+            ),
+        ),
     )
 
     self.value = nnx.Linear(
@@ -708,14 +722,20 @@ def __init__(
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
-        bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
+        bias_init=nnx.with_partitioning(
+            nnx.initializers.zeros,
+            (
+                None,
+                "embed",
+            ),
+        ),
     )
 
     self.proj_attn = nnx.Linear(
         rngs=rngs,
         in_features=self.inner_dim,
         out_features=self.inner_dim,
-        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")),
+        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "heads", "embed")),
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
@@ -729,15 +749,27 @@ def __init__(
         rngs=rngs,
         epsilon=eps,
         dtype=dtype,
-        scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)),
+        scale_init=nnx.with_partitioning(
+            nnx.initializers.ones,
+            (
+                None,
+                "norm",
+            ),
+        ),
         param_dtype=weights_dtype,
     )
 
     self.norm_k = nnx.RMSNorm(
         num_features=self.inner_dim,
         rngs=rngs,
         dtype=dtype,
-        scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)),
+        scale_init=nnx.with_partitioning(
+            nnx.initializers.ones,
+            (
+                None,
+                "norm",
+            ),
+        ),
         param_dtype=weights_dtype,
     )
src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 12 additions & 3 deletions
@@ -25,6 +25,7 @@
 from chex import Array
 from ..utils import logging
 from .. import max_logging
+from .. import common_types
 
 
 logger = logging.get_logger(__name__)
@@ -86,7 +87,7 @@ def rename_key(key):
 
 # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
-def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
+def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict, model_type=None):
   """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
   # conv norm or layer norm
   renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
@@ -109,9 +110,17 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
     renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
     if renamed_pt_tuple_key in random_flax_state_dict:
       if isinstance(random_flax_state_dict[renamed_pt_tuple_key], Partitioned):
-        assert random_flax_state_dict[renamed_pt_tuple_key].value.shape == pt_tensor.T.shape
+        # Wan 2.1 uses nnx.scan and nnx.vmap, which stack layer weights; the stacked
+        # shapes will not match the original per-layer weights, so skip the check.
+        if model_type is not None and model_type == common_types.WAN_MODEL:
+          pass
+        else:
+          assert random_flax_state_dict[renamed_pt_tuple_key].value.shape == pt_tensor.T.shape
       else:
-        assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
+        if model_type is not None and model_type == common_types.WAN_MODEL:
+          pass
+        else:
+          assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
       return renamed_pt_tuple_key, pt_tensor.T
 
     if (
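
The relaxed assert reflects a real mismatch: the PyTorch checkpoint stores one [out_features, in_features] kernel per block, while the scanned Flax model holds a single [num_blocks, in_features, out_features] parameter. A small illustration of the shape relationship (hypothetical sizes, not the actual loader code):

import numpy as np

num_blocks, embed, inner = 4, 128, 128

# Per-block PyTorch Linear kernels, stored as [out_features, in_features].
pt_kernels = [np.ones((inner, embed), np.float32) for _ in range(num_blocks)]

# Transpose each kernel to Flax's [in, out] layout, then stack along a new
# leading axis to match a parameter created under nnx.vmap/nnx.scan.
flax_kernel = np.stack([k.T for k in pt_kernels], axis=0)
assert flax_kernel.shape == (num_blocks, embed, inner)
# A single pt_kernels[i].T has shape (embed, inner) and can never equal the
# stacked shape, which is why the per-tensor assert is skipped for Wan 2.1.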
