From 2b8549aa81c25d084c0544ef272c6980d402d221 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 22 Jul 2025 23:20:29 +0000 Subject: [PATCH 1/9] fixes ssim. --- src/maxdiffusion/trainers/wan_trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 31ce039ad..2de705a88 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -49,11 +49,14 @@ def print_ssim(pretrained_video_path, posttrained_video_path): pretrained_video = video_processor.preprocess_video(pretrained_video) pretrained_video = np.array(pretrained_video) pretrained_video = np.transpose(pretrained_video, (0, 2, 3, 4, 1)) + pretrained_video = np.uint8(255 * pretrained_video) posttrained_video = load_video(posttrained_video_path[0]) posttrained_video = video_processor.preprocess_video(posttrained_video) posttrained_video = np.array(posttrained_video) posttrained_video = np.transpose(posttrained_video, (0, 2, 3, 4, 1)) + posttrained_video = np.uint8(255 * posttrained_video) + ssim_compare = ssim(pretrained_video[0], posttrained_video[0], multichannel=True, channel_axis=-1, data_range=255) max_logging.log(f"SSIM score after training is {ssim_compare}") From 220f24bcafcab4eee8e178d6036acab366f12d1d Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 23 Jul 2025 02:32:30 +0000 Subject: [PATCH 2/9] adds pusav1 video dataset. --- .../wan_pusav1_to_tfrecords.py | 131 ++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py diff --git a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py new file mode 100644 index 000000000..2514bc0eb --- /dev/null +++ b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py @@ -0,0 +1,131 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +""" +Prepare tfrecords with latents and text embeddings preprocessed. +1. Download the dataset +""" + +import os +import functools +from absl import app +from typing import Sequence, Union, List +from datasets import load_dataset +import csv +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from maxdiffusion import pyconfig, max_utils +from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline +from maxdiffusion.video_processor import VideoProcessor + +import torch +import tensorflow as tf + + +def image_feature(value): + """Returns a bytes_list from a string / byte.""" + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()])) + + +def bytes_feature(value): + """Returns a bytes_list from a string / byte.""" + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()])) + + +def float_feature(value): + """Returns a float_list from a float / double.""" + return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) + + +def int64_feature(value): + """Returns an int64_list from a bool / enum / int / uint.""" + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def float_feature_list(value): + """Returns a list of float_list from a float / double.""" + return tf.train.Feature(float_list=tf.train.FloatList(value=value)) + + +def create_example(latent, hidden_states): + latent = tf.io.serialize_tensor(latent) + hidden_states = tf.io.serialize_tensor(hidden_states) + feature = { + "latents": bytes_feature(latent), + "encoder_hidden_states": bytes_feature(hidden_states), + } + example = tf.train.Example(features=tf.train.Features(feature=feature)) + return example.SerializeToString() + +def generate_dataset(config): + + tfrecords_dir = config.tfrecords_dir + if not os.path.exists(tfrecords_dir): + os.makedirs(tfrecords_dir) + + tf_rec_num = 0 + no_records_per_shard = config.no_records_per_shard + global_record_count = 0 + writer = tf.io.TFRecordWriter( + tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard)) + ) + shard_record_count = 0 + + # Load dataset + metadata_path = os.path.join(config.train_data_dir, "metadata.csv") + with open(metadata_path, 'r', newline='') as file: + # Create a csv.reader object + csv_reader = csv.reader(file) + next(csv_reader) + + # If your CSV has a header row, you can skip it + # next(csv_reader, None) + + # Iterate over each row in the CSV file + for row in csv_reader: + video_name = row[0] + pth_path = os.path.join(config.train_data_dir,"train", f"{video_name}.tensors.pth") + loaded_state_dict = torch.load(pth_path, map_location=torch.device('cpu')) + prompt_embeds = loaded_state_dict["prompt_emb"]["context"] + latent = loaded_state_dict["latents"] + # Format we want(4, 16, 1, 64, 64) + latent = jnp.array(latent.float().numpy(), dtype=config.weights_dtype) + prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=config.weights_dtype) + writer.write(create_example(latent, prompt_embeds)) + shard_record_count += 1 + global_record_count += 1 + + if shard_record_count >= no_records_per_shard: + writer.close() + tf_rec_num += 1 + writer = tf.io.TFRecordWriter( + tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard)) + ) + shard_record_count = 0 + +def run(config): + generate_dataset(config) + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) From e9eb4ca4db4ef08510f368a2e3b7aa3b27e49a77 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 23 Jul 2025 04:39:55 +0000 Subject: [PATCH 3/9] wip - adds trainer and attn changes. --- .../data_preprocessing/wan_pusav1_to_tfrecords.py | 8 +++++--- src/maxdiffusion/models/attention_flax.py | 4 ++-- src/maxdiffusion/trainers/wan_trainer.py | 14 ++++++++------ 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py index 2514bc0eb..adec70abc 100644 --- a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py +++ b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py @@ -103,9 +103,11 @@ def generate_dataset(config): loaded_state_dict = torch.load(pth_path, map_location=torch.device('cpu')) prompt_embeds = loaded_state_dict["prompt_emb"]["context"] latent = loaded_state_dict["latents"] - # Format we want(4, 16, 1, 64, 64) - latent = jnp.array(latent.float().numpy(), dtype=config.weights_dtype) - prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=config.weights_dtype) + + # Format we want(Batch, channels, Frames, Height, Width) + # Save them as float32 because numpy cannot read bfloat16. + latent = jnp.array(latent.float().numpy(), dtype=jnp.float32) + prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=jnp.float32) writer.write(create_example(latent, prompt_embeds)) shard_record_count += 1 global_record_count += 1 diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index a00928e3e..7e1bc4918 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -380,7 +380,7 @@ def _apply_attention( ) else: can_use_flash_attention = True - + can_use_flash_attention=True if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention: return _apply_attention_dot( query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention @@ -509,7 +509,7 @@ def __init__( heads: int, dim_head: int, use_memory_efficient_attention: bool = False, - split_head_dim: bool = False, + split_head_dim: bool = True, float32_qk_product: bool = True, axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV), axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV), diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 2de705a88..6dd51f12a 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -165,7 +165,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera state = state.to_pure_dict() p_train_step = jax.jit( - functools.partial(train_step, scheduler=pipeline.scheduler), + functools.partial(train_step, scheduler=pipeline.scheduler, config=self.config), donate_argnums=(0,), ) rng = jax.random.key(self.config.seed) @@ -219,16 +219,18 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera return pipeline -def train_step(state, graphdef, scheduler_state, data, rng, scheduler): - return step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng) +def train_step(state, graphdef, scheduler_state, data, rng, scheduler, config): + return step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config) -def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng): +def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config): _, new_rng, timestep_rng = jax.random.split(rng, num=3) def loss_fn(model): - latents = data["latents"] - encoder_hidden_states = data["encoder_hidden_states"] + latents = data["latents"].astype(config.weights_dtype) + encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype) + # TODO - fix tf record conversion. + encoder_hidden_states = jax.numpy.squeeze(encoder_hidden_states, axis=1) bsz = latents.shape[0] timesteps = jax.random.randint( timestep_rng, From aa442f93faf26ed772131b09830732b23ec60da5 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 23 Jul 2025 04:50:57 +0000 Subject: [PATCH 4/9] force splash attention for cross attention. --- src/maxdiffusion/models/attention_flax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 7e1bc4918..834ca9f64 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -380,7 +380,6 @@ def _apply_attention( ) else: can_use_flash_attention = True - can_use_flash_attention=True if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention: return _apply_attention_dot( query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention @@ -513,7 +512,8 @@ def __init__( float32_qk_product: bool = True, axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV), axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV), - flash_min_seq_length: int = 4096, + # Uses splash attention on cross attention. + flash_min_seq_length: int = 0, flash_block_sizes: BlockSizes = None, dtype: DType = jnp.float32, quant: Quant = None, From 3d2edcca81be30f5f1d9be987144b655630ac89e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 23 Jul 2025 18:56:22 +0000 Subject: [PATCH 5/9] use nnx.scan over for loop. --- src/maxdiffusion/common_types.py | 2 + .../models/modeling_flax_pytorch_utils.py | 15 +++++-- .../wan/transformers/transformer_wan.py | 39 ++++++++++--------- src/maxdiffusion/models/wan/wan_utils.py | 18 +++++++-- 4 files changed, 49 insertions(+), 25 deletions(-) diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index b75f5ceec..f03864da0 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -43,3 +43,5 @@ KEEP_1 = "activation_keep_1" KEEP_2 = "activation_keep_2" CONV_OUT = "activation_conv_out_channels" + +WAN_MODEL = "Wan2.1" diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 409b425b0..0edab5070 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -25,6 +25,7 @@ from chex import Array from ..utils import logging from .. import max_logging +from .. import common_types logger = logging.get_logger(__name__) @@ -86,7 +87,7 @@ def rename_key(key): # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py -def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): +def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict, model_type=None): """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" # conv norm or layer norm renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) @@ -109,9 +110,17 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name) if renamed_pt_tuple_key in random_flax_state_dict: if isinstance(random_flax_state_dict[renamed_pt_tuple_key], Partitioned): - assert random_flax_state_dict[renamed_pt_tuple_key].value.shape == pt_tensor.T.shape + # Wan 2.1 uses nnx.scan and nnx.vmap which stacks layer weights which will cause a shape mismatch + # from the original weights which are not stacked. + if model_type is not None and model_type == common_types.WAN_MODEL: + pass + else: + assert random_flax_state_dict[renamed_pt_tuple_key].value.shape == pt_tensor.T.shape else: - assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape + if model_type is not None and model_type == common_types.WAN_MODEL: + pass + else: + assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape return renamed_pt_tuple_key, pt_tensor.T if ( diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index e0db9dd16..2b56e7e1b 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -359,6 +359,7 @@ def __init__( ): inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels + self.num_layers = num_layers # 1. Patch & position embedding self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) @@ -396,9 +397,10 @@ def __init__( ) # 3. Transformer blocks - blocks = [] - for _ in range(num_layers): - block = WanTransformerBlock( + @nnx.split_rngs(splits=num_layers) + @nnx.vmap + def init_block(rngs): + return WanTransformerBlock( rngs=rngs, dim=inner_dim, ffn_dim=ffn_dim, @@ -414,8 +416,7 @@ def __init__( precision=precision, attention=attention, ) - blocks.append(block) - self.blocks = blocks + self.blocks = init_block(rngs) self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False) self.proj_out = nnx.Linear( @@ -463,21 +464,21 @@ def __call__( if encoder_hidden_states_image is not None: raise NotImplementedError("img2vid is not yet implemented.") - def skip_block_true(hidden_states): - split_bs = hidden_states.shape[0] // 2 - prev_neg_hidden_states = hidden_states[split_bs:] + def scan_fn(carry, block): + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb = carry hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) - hidden_states = jnp.concatenate([hidden_states[:split_bs], prev_neg_hidden_states], axis=0) - return hidden_states - - for block_idx, block in enumerate(self.blocks): - should_skip_block = slg_mask[block_idx] & is_uncond - hidden_states = jax.lax.cond( - should_skip_block, - lambda _: skip_block_true(hidden_states), # If true, pass through original hidden_states (skip block) - lambda _: block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb), - hidden_states, - ) + return (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + + initial_carry = (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + final_carry = nnx.scan( + scan_fn, + length=self.num_layers, + in_axes=(nnx.Carry, 0), + out_axes=nnx.Carry, + )(initial_carry, self.blocks) + + hidden_states = final_carry[0] + shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).astype(hidden_states.dtype) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 6623e78df..cd215463f 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -8,6 +8,7 @@ from safetensors import safe_open from flax.traverse_util import unflatten_dict, flatten_dict from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict) +from ...common_types import WAN_MODEL CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid" WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX" @@ -82,7 +83,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di pt_tuple_key = tuple(renamed_pt_key.split(".")) - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL) flax_key = rename_for_nnx(flax_key) flax_key = _tuple_str_to_int(flax_key) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) @@ -117,7 +118,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di pt_tuple_key = tuple(renamed_pt_key.split(".")) - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL) flax_key = rename_for_nnx(flax_key) flax_key = _tuple_str_to_int(flax_key) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) @@ -196,9 +197,20 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm") pt_tuple_key = tuple(renamed_pt_key.split(".")) - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) + if "blocks" in pt_tuple_key: + new_key = ("blocks",) + pt_tuple_key[2:] + block_index = int(pt_tuple_key[1]) + pt_tuple_key = new_key + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL) flax_key = rename_for_nnx(flax_key) flax_key = _tuple_str_to_int(flax_key) + + if "blocks" in flax_key: + if flax_key in flax_state_dict: + new_tensor = flax_state_dict[flax_key] + else: + new_tensor = jnp.zeros((40,) + flax_tensor.shape) + flax_tensor = new_tensor.at[block_index].set(flax_tensor) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) flax_state_dict = unflatten_dict(flax_state_dict) From 34968e059ee2161f078e680e1bab0989ecef0d7b Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 23 Jul 2025 20:18:57 +0000 Subject: [PATCH 6/9] support wan transformers for nnx.scan. --- .../wan_pusav1_to_tfrecords.py | 2 +- src/maxdiffusion/models/wan/wan_utils.py | 23 +++++++++++++++++++ src/maxdiffusion/trainers/wan_trainer.py | 2 -- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py index adec70abc..d2772a262 100644 --- a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py +++ b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py @@ -101,7 +101,7 @@ def generate_dataset(config): video_name = row[0] pth_path = os.path.join(config.train_data_dir,"train", f"{video_name}.tensors.pth") loaded_state_dict = torch.load(pth_path, map_location=torch.device('cpu')) - prompt_embeds = loaded_state_dict["prompt_emb"]["context"] + prompt_embeds = loaded_state_dict["prompt_emb"]["context"].squeeze() latent = loaded_state_dict["latents"] # Format we want(Batch, channels, Frames, Height, Width) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index cd215463f..cc7498d0a 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -83,9 +83,20 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di pt_tuple_key = tuple(renamed_pt_key.split(".")) + if "blocks" in pt_tuple_key: + new_key = ("blocks",) + pt_tuple_key[2:] + block_index = int(pt_tuple_key[1]) + pt_tuple_key = new_key flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL) flax_key = rename_for_nnx(flax_key) flax_key = _tuple_str_to_int(flax_key) + + if "blocks" in flax_key: + if flax_key in flax_state_dict: + new_tensor = flax_state_dict[flax_key] + else: + new_tensor = jnp.zeros((40,) + flax_tensor.shape) + flax_tensor = new_tensor.at[block_index].set(flax_tensor) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) flax_state_dict = unflatten_dict(flax_state_dict) @@ -118,9 +129,21 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di pt_tuple_key = tuple(renamed_pt_key.split(".")) + if "blocks" in pt_tuple_key: + new_key = ("blocks",) + pt_tuple_key[2:] + block_index = int(pt_tuple_key[1]) + pt_tuple_key = new_key flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL) flax_key = rename_for_nnx(flax_key) flax_key = _tuple_str_to_int(flax_key) + + + if "blocks" in flax_key: + if flax_key in flax_state_dict: + new_tensor = flax_state_dict[flax_key] + else: + new_tensor = jnp.zeros((40,) + flax_tensor.shape) + flax_tensor = new_tensor.at[block_index].set(flax_tensor) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) flax_state_dict = unflatten_dict(flax_state_dict) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 6dd51f12a..a072c77db 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -229,8 +229,6 @@ def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, confi def loss_fn(model): latents = data["latents"].astype(config.weights_dtype) encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype) - # TODO - fix tf record conversion. - encoder_hidden_states = jax.numpy.squeeze(encoder_hidden_states, axis=1) bsz = latents.shape[0] timesteps = jax.random.randint( timestep_rng, From 0df5659f1bf6dfd96b2b917f0ecc0647f444377d Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 23 Jul 2025 23:09:57 +0000 Subject: [PATCH 7/9] fix ag from vmap/scan. --- .../wan_pusav1_to_tfrecords.py | 21 +++++++------------ src/maxdiffusion/models/attention_flax.py | 18 +++++++++------- .../wan/transformers/transformer_wan.py | 11 +++++----- src/maxdiffusion/models/wan/wan_utils.py | 13 ++++++++---- 4 files changed, 33 insertions(+), 30 deletions(-) diff --git a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py index d2772a262..c0134bab9 100644 --- a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py +++ b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py @@ -20,18 +20,11 @@ """ import os -import functools from absl import app -from typing import Sequence, Union, List -from datasets import load_dataset +from typing import Sequence import csv -import numpy as np -import jax import jax.numpy as jnp -from jax.sharding import Mesh -from maxdiffusion import pyconfig, max_utils -from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline -from maxdiffusion.video_processor import VideoProcessor +from maxdiffusion import pyconfig import torch import tensorflow as tf @@ -72,6 +65,7 @@ def create_example(latent, hidden_states): example = tf.train.Example(features=tf.train.Features(feature=feature)) return example.SerializeToString() + def generate_dataset(config): tfrecords_dir = config.tfrecords_dir @@ -88,7 +82,7 @@ def generate_dataset(config): # Load dataset metadata_path = os.path.join(config.train_data_dir, "metadata.csv") - with open(metadata_path, 'r', newline='') as file: + with open(metadata_path, "r", newline="") as file: # Create a csv.reader object csv_reader = csv.reader(file) next(csv_reader) @@ -99,11 +93,11 @@ def generate_dataset(config): # Iterate over each row in the CSV file for row in csv_reader: video_name = row[0] - pth_path = os.path.join(config.train_data_dir,"train", f"{video_name}.tensors.pth") - loaded_state_dict = torch.load(pth_path, map_location=torch.device('cpu')) + pth_path = os.path.join(config.train_data_dir, "train", f"{video_name}.tensors.pth") + loaded_state_dict = torch.load(pth_path, map_location=torch.device("cpu")) prompt_embeds = loaded_state_dict["prompt_emb"]["context"].squeeze() latent = loaded_state_dict["latents"] - + # Format we want(Batch, channels, Frames, Height, Width) # Save them as float32 because numpy cannot read bfloat16. latent = jnp.array(latent.float().numpy(), dtype=jnp.float32) @@ -120,6 +114,7 @@ def generate_dataset(config): ) shard_record_count = 0 + def run(config): generate_dataset(config) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 834ca9f64..06ae92c2c 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -674,8 +674,10 @@ def __init__( dtype=dtype, quant=quant, ) - - kernel_axes = ("embed", "heads") + # None axes corresponds to the stacked weights across all blocks + # because of the use of nnx.vmap and nnx.scan. + # Dims are [num_blocks, embed, heads] + kernel_axes = (None, "embed", "heads") qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes) self.query = nnx.Linear( @@ -686,7 +688,7 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)), ) self.key = nnx.Linear( @@ -697,7 +699,7 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)), ) self.value = nnx.Linear( @@ -708,14 +710,14 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)), ) self.proj_attn = nnx.Linear( rngs=rngs, in_features=self.inner_dim, out_features=self.inner_dim, - kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")), + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "heads", "embed")), dtype=dtype, param_dtype=weights_dtype, precision=precision, @@ -729,7 +731,7 @@ def __init__( rngs=rngs, epsilon=eps, dtype=dtype, - scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)), + scale_init=nnx.with_partitioning(nnx.initializers.ones, (None, "norm",)), param_dtype=weights_dtype, ) @@ -737,7 +739,7 @@ def __init__( num_features=self.inner_dim, rngs=rngs, dtype=dtype, - scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)), + scale_init=nnx.with_partitioning(nnx.initializers.ones, (None, "norm",)), param_dtype=weights_dtype, ) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 2b56e7e1b..8a3b72566 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -398,7 +398,7 @@ def __init__( # 3. Transformer blocks @nnx.split_rngs(splits=num_layers) - @nnx.vmap + @nnx.vmap(in_axes=0, out_axes=0) def init_block(rngs): return WanTransformerBlock( rngs=rngs, @@ -416,6 +416,7 @@ def init_block(rngs): precision=precision, attention=attention, ) + self.blocks = init_block(rngs) self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False) @@ -471,10 +472,10 @@ def scan_fn(carry, block): initial_carry = (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) final_carry = nnx.scan( - scan_fn, - length=self.num_layers, - in_axes=(nnx.Carry, 0), - out_axes=nnx.Carry, + scan_fn, + length=self.num_layers, + in_axes=(nnx.Carry, 0), + out_axes=nnx.Carry, )(initial_carry, self.blocks) hidden_states = final_carry[0] diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index cc7498d0a..2ceb0f7e6 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -87,7 +87,9 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di new_key = ("blocks",) + pt_tuple_key[2:] block_index = int(pt_tuple_key[1]) pt_tuple_key = new_key - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL) + flax_key, flax_tensor = rename_key_and_reshape_tensor( + pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL + ) flax_key = rename_for_nnx(flax_key) flax_key = _tuple_str_to_int(flax_key) @@ -133,11 +135,12 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di new_key = ("blocks",) + pt_tuple_key[2:] block_index = int(pt_tuple_key[1]) pt_tuple_key = new_key - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL) + flax_key, flax_tensor = rename_key_and_reshape_tensor( + pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL + ) flax_key = rename_for_nnx(flax_key) flax_key = _tuple_str_to_int(flax_key) - if "blocks" in flax_key: if flax_key in flax_state_dict: new_tensor = flax_state_dict[flax_key] @@ -224,7 +227,9 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d new_key = ("blocks",) + pt_tuple_key[2:] block_index = int(pt_tuple_key[1]) pt_tuple_key = new_key - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL) + flax_key, flax_tensor = rename_key_and_reshape_tensor( + pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL + ) flax_key = rename_for_nnx(flax_key) flax_key = _tuple_str_to_int(flax_key) From 44224c2a5b1b373134b4eda6a59811f72a056ac5 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 23 Jul 2025 23:11:06 +0000 Subject: [PATCH 8/9] linting. --- src/maxdiffusion/models/attention_flax.py | 40 ++++++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 06ae92c2c..3099a5bc0 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -688,7 +688,13 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)), + bias_init=nnx.with_partitioning( + nnx.initializers.zeros, + ( + None, + "embed", + ), + ), ) self.key = nnx.Linear( @@ -699,7 +705,13 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)), + bias_init=nnx.with_partitioning( + nnx.initializers.zeros, + ( + None, + "embed", + ), + ), ) self.value = nnx.Linear( @@ -710,7 +722,13 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)), + bias_init=nnx.with_partitioning( + nnx.initializers.zeros, + ( + None, + "embed", + ), + ), ) self.proj_attn = nnx.Linear( @@ -731,7 +749,13 @@ def __init__( rngs=rngs, epsilon=eps, dtype=dtype, - scale_init=nnx.with_partitioning(nnx.initializers.ones, (None, "norm",)), + scale_init=nnx.with_partitioning( + nnx.initializers.ones, + ( + None, + "norm", + ), + ), param_dtype=weights_dtype, ) @@ -739,7 +763,13 @@ def __init__( num_features=self.inner_dim, rngs=rngs, dtype=dtype, - scale_init=nnx.with_partitioning(nnx.initializers.ones, (None, "norm",)), + scale_init=nnx.with_partitioning( + nnx.initializers.ones, + ( + None, + "norm", + ), + ), param_dtype=weights_dtype, ) From abbab5b9f070c284547b598b9a81b0a963907a11 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 24 Jul 2025 00:14:30 +0000 Subject: [PATCH 9/9] remove slg to simplify the code. --- src/maxdiffusion/configs/base_wan_14b.yml | 4 ---- src/maxdiffusion/generate_wan.py | 13 ------------ .../wan/transformers/transformer_wan.py | 2 -- .../pipelines/wan/wan_pipeline.py | 20 +------------------ .../tests/wan_transformer_test.py | 6 +----- src/maxdiffusion/trainers/wan_trainer.py | 6 ++---- 6 files changed, 4 insertions(+), 47 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index f3799e79f..911127896 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -234,10 +234,6 @@ num_frames: 81 guidance_scale: 5.0 flow_shift: 3.0 -# skip layer guidance -slg_layers: [9] -slg_start: 0.2 -slg_end: 1.0 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 30 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index ad10cdf06..d3c8d47cf 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -29,10 +29,6 @@ def run(config, pipeline=None, filename_prefix=""): pipeline = WanPipeline.from_pretrained(config) s0 = time.perf_counter() - # Skip layer guidance - slg_layers = config.slg_layers - slg_start = config.slg_start - slg_end = config.slg_end # If global_batch_size % jax.device_count is not 0, use FSDP sharding. global_batch_size = config.global_batch_size if global_batch_size != 0: @@ -55,9 +51,6 @@ def run(config, pipeline=None, filename_prefix=""): num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, - slg_layers=slg_layers, - slg_start=slg_start, - slg_end=slg_end, ) print("compile time: ", (time.perf_counter() - s0)) @@ -76,9 +69,6 @@ def run(config, pipeline=None, filename_prefix=""): num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, - slg_layers=slg_layers, - slg_start=slg_start, - slg_end=slg_end, ) print("generation time: ", (time.perf_counter() - s0)) @@ -93,9 +83,6 @@ def run(config, pipeline=None, filename_prefix=""): num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, - slg_layers=slg_layers, - slg_start=slg_start, - slg_end=slg_end, ) max_utils.deactivate_profiler(config) print("generation time: ", (time.perf_counter() - s0)) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 8a3b72566..b1ae70b7a 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -440,8 +440,6 @@ def __call__( hidden_states: jax.Array, timestep: jax.Array, encoder_hidden_states: jax.Array, - is_uncond: jax.Array, # jnp.bool_ scalar - slg_mask: jax.Array, # jnp.bool_ array of shape (num_blocks,) encoder_hidden_states_image: Optional[jax.Array] = None, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index ed5b84489..abf449291 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -376,9 +376,6 @@ def __call__( prompt_embeds: jax.Array = None, negative_prompt_embeds: jax.Array = None, vae_only: bool = False, - slg_layers: List[int] = None, - slg_start: float = 0.0, - slg_end: float = 1.0, ): if not vae_only: if num_frames % self.vae_scale_factor_temporal != 1: @@ -434,9 +431,6 @@ def __call__( num_inference_steps=num_inference_steps, scheduler=self.scheduler, scheduler_state=scheduler_state, - slg_layers=slg_layers, - slg_start=slg_start, - slg_end=slg_end, num_transformer_layers=self.transformer.config.num_layers, ) @@ -471,15 +465,11 @@ def transformer_forward_pass( latents, timestep, prompt_embeds, - is_uncond, - slg_mask, do_classifier_free_guidance, guidance_scale, ): wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) - noise_pred = wan_transformer( - hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, is_uncond=is_uncond, slg_mask=slg_mask - ) + noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) if do_classifier_free_guidance: bsz = latents.shape[0] // 2 noise_uncond = noise_pred[bsz:] @@ -502,17 +492,11 @@ def run_inference( scheduler: FlaxUniPCMultistepScheduler, num_transformer_layers: int, scheduler_state, - slg_layers: List[int] = None, - slg_start: float = 0.0, - slg_end: float = 1.0, ): do_classifier_free_guidance = guidance_scale > 1.0 if do_classifier_free_guidance: prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) for step in range(num_inference_steps): - slg_mask = jnp.zeros(num_transformer_layers, dtype=jnp.bool_) - if slg_layers and int(slg_start * num_inference_steps) <= step < int(slg_end * num_inference_steps): - slg_mask = slg_mask.at[jnp.array(slg_layers)].set(True) t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] if do_classifier_free_guidance: latents = jnp.concatenate([latents] * 2) @@ -525,8 +509,6 @@ def run_inference( latents, timestep, prompt_embeds, - is_uncond=jnp.array(True, dtype=jnp.bool_), - slg_mask=slg_mask, do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale, ) diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 4ea50cc7a..01d169b01 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -269,11 +269,7 @@ def test_wan_model(self): dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096)) with mesh: dummy_output = wan_model( - hidden_states=dummy_hidden_states, - timestep=dummy_timestep, - encoder_hidden_states=dummy_encoder_hidden_states, - is_uncond=jnp.array(True, dtype=jnp.bool_), - slg_mask=jnp.zeros(num_layers, dtype=jnp.bool_), + hidden_states=dummy_hidden_states, timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states ) assert dummy_output.shape == hidden_states_shape diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index a072c77db..d568709f0 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -49,13 +49,13 @@ def print_ssim(pretrained_video_path, posttrained_video_path): pretrained_video = video_processor.preprocess_video(pretrained_video) pretrained_video = np.array(pretrained_video) pretrained_video = np.transpose(pretrained_video, (0, 2, 3, 4, 1)) - pretrained_video = np.uint8(255 * pretrained_video) + pretrained_video = np.uint8((pretrained_video + 1) * 255 / 2) posttrained_video = load_video(posttrained_video_path[0]) posttrained_video = video_processor.preprocess_video(posttrained_video) posttrained_video = np.array(posttrained_video) posttrained_video = np.transpose(posttrained_video, (0, 2, 3, 4, 1)) - posttrained_video = np.uint8(255 * posttrained_video) + posttrained_video = np.uint8((posttrained_video + 1) * 255 / 2) ssim_compare = ssim(pretrained_video[0], posttrained_video[0], multichannel=True, channel_axis=-1, data_range=255) @@ -243,8 +243,6 @@ def loss_fn(model): hidden_states=noisy_latents, timestep=timesteps, encoder_hidden_states=encoder_hidden_states, - is_uncond=jnp.array(False, dtype=jnp.bool_), - slg_mask=jnp.zeros(1, dtype=jnp.bool_), ) training_target = scheduler.training_target(latents, noise, timesteps)