2 changes: 2 additions & 0 deletions src/maxdiffusion/common_types.py
@@ -43,3 +43,5 @@
KEEP_1 = "activation_keep_1"
KEEP_2 = "activation_keep_2"
CONV_OUT = "activation_conv_out_channels"

WAN_MODEL = "Wan2.1"
4 changes: 0 additions & 4 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -234,10 +234,6 @@ num_frames: 81
guidance_scale: 5.0
flow_shift: 3.0

# skip layer guidance
slg_layers: [9]
slg_start: 0.2
slg_end: 1.0
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 30
128 changes: 128 additions & 0 deletions src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py
@@ -0,0 +1,128 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
Prepares TFRecords with preprocessed latents and text embeddings.
Prerequisite: download the dataset (metadata.csv plus the per-video *.tensors.pth files) into config.train_data_dir.
"""

import os
from absl import app
from typing import Sequence
import csv
import jax.numpy as jnp
from maxdiffusion import pyconfig

import torch
import tensorflow as tf


def image_feature(value):
"""Returns a bytes_list from a string / byte."""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()]))


def bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()]))


def float_feature(value):
"""Returns a float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def int64_feature(value):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def float_feature_list(value):
"""Returns a list of float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def create_example(latent, hidden_states):
latent = tf.io.serialize_tensor(latent)
hidden_states = tf.io.serialize_tensor(hidden_states)
feature = {
"latents": bytes_feature(latent),
"encoder_hidden_states": bytes_feature(hidden_states),
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
return example.SerializeToString()


def generate_dataset(config):

tfrecords_dir = config.tfrecords_dir
if not os.path.exists(tfrecords_dir):
os.makedirs(tfrecords_dir)

tf_rec_num = 0
no_records_per_shard = config.no_records_per_shard
global_record_count = 0
writer = tf.io.TFRecordWriter(
tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard))
)
shard_record_count = 0

# Load dataset
metadata_path = os.path.join(config.train_data_dir, "metadata.csv")
with open(metadata_path, "r", newline="") as file:
# Create a csv.reader object
csv_reader = csv.reader(file)
# Skip the header row
next(csv_reader)

# Iterate over each row in the CSV file
for row in csv_reader:
video_name = row[0]
pth_path = os.path.join(config.train_data_dir, "train", f"{video_name}.tensors.pth")
loaded_state_dict = torch.load(pth_path, map_location=torch.device("cpu"))
prompt_embeds = loaded_state_dict["prompt_emb"]["context"].squeeze()
latent = loaded_state_dict["latents"]

# Format we want: (Batch, Channels, Frames, Height, Width)
# Save them as float32 because numpy cannot read bfloat16.
latent = jnp.array(latent.float().numpy(), dtype=jnp.float32)
prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=jnp.float32)
writer.write(create_example(latent, prompt_embeds))
shard_record_count += 1
global_record_count += 1

if shard_record_count >= no_records_per_shard:
writer.close()
tf_rec_num += 1
writer = tf.io.TFRecordWriter(
tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard))
)
shard_record_count = 0


def run(config):
generate_dataset(config)


def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
run(pyconfig.config)


if __name__ == "__main__":
app.run(main)
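For reference, the shards written by this script can be read back by parsing the two serialized-tensor features. The following is a minimal sketch, not part of this PR: the feature names and float32 dtype come from create_example above, while the parse_example helper and the glob path are illustrative.

import tensorflow as tf

def parse_example(serialized):
  """Parses one record written by create_example back into (latents, encoder_hidden_states)."""
  features = {
      "latents": tf.io.FixedLenFeature([], tf.string),
      "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
  }
  parsed = tf.io.parse_single_example(serialized, features)
  # The tensors were serialized as float32, so parse them back with the same dtype.
  latents = tf.io.parse_tensor(parsed["latents"], out_type=tf.float32)
  encoder_hidden_states = tf.io.parse_tensor(parsed["encoder_hidden_states"], out_type=tf.float32)
  return latents, encoder_hidden_states

# Example usage (hypothetical directory): stream all shards and decode them.
files = tf.io.gfile.glob("/path/to/tfrecords_dir/file_*.tfrec")
dataset = tf.data.TFRecordDataset(files).map(parse_example)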
13 changes: 0 additions & 13 deletions src/maxdiffusion/generate_wan.py
@@ -29,10 +29,6 @@ def run(config, pipeline=None, filename_prefix=""):
pipeline = WanPipeline.from_pretrained(config)
s0 = time.perf_counter()

# Skip layer guidance
slg_layers = config.slg_layers
slg_start = config.slg_start
slg_end = config.slg_end
# If global_batch_size % jax.device_count is not 0, use FSDP sharding.
global_batch_size = config.global_batch_size
if global_batch_size != 0:
@@ -55,9 +51,6 @@
num_frames=config.num_frames,
num_inference_steps=config.num_inference_steps,
guidance_scale=config.guidance_scale,
slg_layers=slg_layers,
slg_start=slg_start,
slg_end=slg_end,
)

print("compile time: ", (time.perf_counter() - s0))
@@ -76,9 +69,6 @@
num_frames=config.num_frames,
num_inference_steps=config.num_inference_steps,
guidance_scale=config.guidance_scale,
slg_layers=slg_layers,
slg_start=slg_start,
slg_end=slg_end,
)
print("generation time: ", (time.perf_counter() - s0))

@@ -93,9 +83,6 @@
num_frames=config.num_frames,
num_inference_steps=config.num_inference_steps,
guidance_scale=config.guidance_scale,
slg_layers=slg_layers,
slg_start=slg_start,
slg_end=slg_end,
)
max_utils.deactivate_profiler(config)
print("generation time: ", (time.perf_counter() - s0))
54 changes: 43 additions & 11 deletions src/maxdiffusion/models/attention_flax.py
@@ -380,7 +380,6 @@ def _apply_attention(
)
else:
can_use_flash_attention = True

if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention:
return _apply_attention_dot(
query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention
@@ -509,11 +508,12 @@ def __init__(
heads: int,
dim_head: int,
use_memory_efficient_attention: bool = False,
split_head_dim: bool = False,
split_head_dim: bool = True,
float32_qk_product: bool = True,
axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
flash_min_seq_length: int = 4096,
# A default of 0 lets splash attention also be used for cross attention.
flash_min_seq_length: int = 0,
flash_block_sizes: BlockSizes = None,
dtype: DType = jnp.float32,
quant: Quant = None,
@@ -674,8 +674,10 @@ def __init__(
dtype=dtype,
quant=quant,
)

kernel_axes = ("embed", "heads")
# The None axis corresponds to the weights stacked across all blocks
# by nnx.vmap and nnx.scan.
# Dims are [num_blocks, embed, heads].
kernel_axes = (None, "embed", "heads")
qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes)

self.query = nnx.Linear(
@@ -686,7 +688,13 @@
dtype=dtype,
param_dtype=weights_dtype,
precision=precision,
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros,
(
None,
"embed",
),
),
)

self.key = nnx.Linear(
@@ -697,7 +705,13 @@
dtype=dtype,
param_dtype=weights_dtype,
precision=precision,
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros,
(
None,
"embed",
),
),
)

self.value = nnx.Linear(
@@ -708,14 +722,20 @@
dtype=dtype,
param_dtype=weights_dtype,
precision=precision,
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
bias_init=nnx.with_partitioning(
nnx.initializers.zeros,
(
None,
"embed",
),
),
)

self.proj_attn = nnx.Linear(
rngs=rngs,
in_features=self.inner_dim,
out_features=self.inner_dim,
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")),
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "heads", "embed")),
dtype=dtype,
param_dtype=weights_dtype,
precision=precision,
@@ -729,15 +749,27 @@
rngs=rngs,
epsilon=eps,
dtype=dtype,
scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)),
scale_init=nnx.with_partitioning(
nnx.initializers.ones,
(
None,
"norm",
),
),
param_dtype=weights_dtype,
)

self.norm_k = nnx.RMSNorm(
num_features=self.inner_dim,
rngs=rngs,
dtype=dtype,
scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)),
scale_init=nnx.with_partitioning(
nnx.initializers.ones,
(
None,
"norm",
),
),
param_dtype=weights_dtype,
)

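For context on the new leading None axis: when the transformer blocks are created under nnx.vmap (and iterated with nnx.scan at call time), each per-block parameter is stored as a single stacked array whose first dimension is the number of blocks, so every partition spec gains a leading None entry. The following is a minimal sketch under that assumption; ToyBlock and its sizes are hypothetical, not the Wan transformer block.

from flax import nnx

class ToyBlock(nnx.Module):
  """A stand-in for a transformer block with a single Linear layer."""

  def __init__(self, rngs: nnx.Rngs):
    self.proj = nnx.Linear(in_features=8, out_features=4, rngs=rngs)

num_blocks = 6

@nnx.split_rngs(splits=num_blocks)
@nnx.vmap(in_axes=0, out_axes=0)
def create_blocks(rngs: nnx.Rngs):
  return ToyBlock(rngs)

blocks = create_blocks(nnx.Rngs(0))
# A single block's kernel would be (8, 4); stacking across blocks prepends the
# block dimension, giving (6, 8, 4) -- hence the leading None in kernel_axes.
print(blocks.proj.kernel.value.shape)  # (6, 8, 4)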
15 changes: 12 additions & 3 deletions src/maxdiffusion/models/modeling_flax_pytorch_utils.py
@@ -25,6 +25,7 @@
from chex import Array
from ..utils import logging
from .. import max_logging
from .. import common_types


logger = logging.get_logger(__name__)
@@ -86,7 +87,7 @@ def rename_key(key):

# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict, model_type=None):
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
# conv norm or layer norm
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
@@ -109,9 +110,17 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict, model_type=None):
renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
if renamed_pt_tuple_key in random_flax_state_dict:
if isinstance(random_flax_state_dict[renamed_pt_tuple_key], Partitioned):
assert random_flax_state_dict[renamed_pt_tuple_key].value.shape == pt_tensor.T.shape
# Wan 2.1 uses nnx.scan and nnx.vmap, which stack layer weights, so the Flax shapes
# will not match the original unstacked weights; skip the shape check for Wan.
if model_type is not None and model_type == common_types.WAN_MODEL:
pass
else:
assert random_flax_state_dict[renamed_pt_tuple_key].value.shape == pt_tensor.T.shape
else:
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
if model_type is not None and model_type == common_types.WAN_MODEL:
pass
else:
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
return renamed_pt_tuple_key, pt_tensor.T

if (
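To illustrate why the check is skipped for Wan: a PyTorch nn.Linear weight is stored as [out_features, in_features], and the converter normally asserts that the transposed tensor matches the Flax kernel shape. For Wan, the Flax kernel holds all blocks stacked along a leading dimension, so the per-layer comparison can never hold. A small sketch with hypothetical sizes (not the actual Wan dimensions):

import numpy as np

num_blocks, in_features, out_features = 4, 16, 16       # hypothetical sizes
pt_weight = np.zeros((out_features, in_features), dtype=np.float32)  # one block, PyTorch layout [out, in]

flax_per_block = pt_weight.T                             # [in, out], what the old assert expects
flax_stacked = np.stack([flax_per_block] * num_blocks)   # [num_blocks, in, out], what nnx.scan/vmap stores

print(flax_per_block.shape)  # (16, 16)
print(flax_stacked.shape)    # (4, 16, 16) -> the per-layer shape check would fail, hence the skip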