diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index b75f5ceec..f03864da0 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -43,3 +43,5 @@ KEEP_1 = "activation_keep_1" KEEP_2 = "activation_keep_2" CONV_OUT = "activation_conv_out_channels" + +WAN_MODEL = "Wan2.1" diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index f3799e79f..911127896 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -234,10 +234,6 @@ num_frames: 81 guidance_scale: 5.0 flow_shift: 3.0 -# skip layer guidance -slg_layers: [9] -slg_start: 0.2 -slg_end: 1.0 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 30 diff --git a/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py new file mode 100644 index 000000000..c0134bab9 --- /dev/null +++ b/src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py @@ -0,0 +1,128 @@ +""" + Copyright 2025 Google LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +""" +Prepare tfrecords with latents and text embeddings preprocessed. +1. 
Download the dataset +""" + +import os +from absl import app +from typing import Sequence +import csv +import jax.numpy as jnp +from maxdiffusion import pyconfig + +import torch +import tensorflow as tf + + +def image_feature(value): + """Returns a bytes_list holding the JPEG encoding of an image tensor.""" + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()])) + + +def bytes_feature(value): + """Returns a bytes_list from a string tensor (uses its .numpy() bytes).""" + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.numpy()])) + + +def float_feature(value): + """Returns a float_list from a float / double.""" + return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) + + +def int64_feature(value): + """Returns an int64_list from a bool / enum / int / uint.""" + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def float_feature_list(value): + """Returns a float_list from a sequence of floats / doubles.""" + return tf.train.Feature(float_list=tf.train.FloatList(value=value)) + + +def create_example(latent, hidden_states): + latent = tf.io.serialize_tensor(latent) + hidden_states = tf.io.serialize_tensor(hidden_states) + feature = { + "latents": bytes_feature(latent), + "encoder_hidden_states": bytes_feature(hidden_states), + } + example = tf.train.Example(features=tf.train.Features(feature=feature)) + return example.SerializeToString() + + +def generate_dataset(config): + + tfrecords_dir = config.tfrecords_dir + if not os.path.exists(tfrecords_dir): + os.makedirs(tfrecords_dir) + + tf_rec_num = 0 + no_records_per_shard = config.no_records_per_shard + global_record_count = 0 + writer = tf.io.TFRecordWriter( + tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard)) + ) + shard_record_count = 0 + + # Load dataset + metadata_path = os.path.join(config.train_data_dir, "metadata.csv") + with open(metadata_path, "r", newline="") as file: + # Create a csv.reader object + csv_reader =
csv.reader(file) + next(csv_reader) + + # The next() call above unconditionally drops the first row as a header, + # so metadata.csv must start with a header row. + + # Iterate over each row. NOTE(review): the last writer is never explicitly closed after this loop. + for row in csv_reader: + video_name = row[0] + pth_path = os.path.join(config.train_data_dir, "train", f"{video_name}.tensors.pth") + loaded_state_dict = torch.load(pth_path, map_location=torch.device("cpu")) + prompt_embeds = loaded_state_dict["prompt_emb"]["context"].squeeze() + latent = loaded_state_dict["latents"] + + # Format we want: (Batch, Channels, Frames, Height, Width) + # Save them as float32 because numpy cannot read bfloat16. + latent = jnp.array(latent.float().numpy(), dtype=jnp.float32) + prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=jnp.float32) + writer.write(create_example(latent, prompt_embeds)) + shard_record_count += 1 + global_record_count += 1 + + if shard_record_count >= no_records_per_shard: + writer.close() + tf_rec_num += 1 + writer = tf.io.TFRecordWriter( + tfrecords_dir + "/file_%.2i-%i.tfrec" % (tf_rec_num, (global_record_count + no_records_per_shard)) + ) + shard_record_count = 0 + + +def run(config): + generate_dataset(config) + + +def main(argv: Sequence[str]) -> None: + pyconfig.initialize(argv) + run(pyconfig.config) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index ad10cdf06..d3c8d47cf 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -29,10 +29,6 @@ def run(config, pipeline=None, filename_prefix=""): pipeline = WanPipeline.from_pretrained(config) s0 = time.perf_counter() - # Skip layer guidance - slg_layers = config.slg_layers - slg_start = config.slg_start - slg_end = config.slg_end # If global_batch_size % jax.device_count is not 0, use FSDP sharding.
global_batch_size = config.global_batch_size if global_batch_size != 0: @@ -55,9 +51,6 @@ def run(config, pipeline=None, filename_prefix=""): num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, - slg_layers=slg_layers, - slg_start=slg_start, - slg_end=slg_end, ) print("compile time: ", (time.perf_counter() - s0)) @@ -76,9 +69,6 @@ def run(config, pipeline=None, filename_prefix=""): num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, - slg_layers=slg_layers, - slg_start=slg_start, - slg_end=slg_end, ) print("generation time: ", (time.perf_counter() - s0)) @@ -93,9 +83,6 @@ def run(config, pipeline=None, filename_prefix=""): num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, - slg_layers=slg_layers, - slg_start=slg_start, - slg_end=slg_end, ) max_utils.deactivate_profiler(config) print("generation time: ", (time.perf_counter() - s0)) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index a00928e3e..3099a5bc0 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -380,7 +380,6 @@ def _apply_attention( ) else: can_use_flash_attention = True - if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention: return _apply_attention_dot( query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention @@ -509,11 +508,12 @@ def __init__( heads: int, dim_head: int, use_memory_efficient_attention: bool = False, - split_head_dim: bool = False, + split_head_dim: bool = True, float32_qk_product: bool = True, axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV), axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV), - flash_min_seq_length: int = 4096, + # Uses splash attention on cross attention. 
+ flash_min_seq_length: int = 0, flash_block_sizes: BlockSizes = None, dtype: DType = jnp.float32, quant: Quant = None, @@ -674,8 +674,10 @@ def __init__( dtype=dtype, quant=quant, ) - - kernel_axes = ("embed", "heads") + # None axes corresponds to the stacked weights across all blocks + # because of the use of nnx.vmap and nnx.scan. + # Dims are [num_blocks, embed, heads] + kernel_axes = (None, "embed", "heads") qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes) self.query = nnx.Linear( @@ -686,7 +688,13 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + bias_init=nnx.with_partitioning( + nnx.initializers.zeros, + ( + None, + "embed", + ), + ), ) self.key = nnx.Linear( @@ -697,7 +705,13 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + bias_init=nnx.with_partitioning( + nnx.initializers.zeros, + ( + None, + "embed", + ), + ), ) self.value = nnx.Linear( @@ -708,14 +722,20 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + bias_init=nnx.with_partitioning( + nnx.initializers.zeros, + ( + None, + "embed", + ), + ), ) self.proj_attn = nnx.Linear( rngs=rngs, in_features=self.inner_dim, out_features=self.inner_dim, - kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")), + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "heads", "embed")), dtype=dtype, param_dtype=weights_dtype, precision=precision, @@ -729,7 +749,13 @@ def __init__( rngs=rngs, epsilon=eps, dtype=dtype, - scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)), + scale_init=nnx.with_partitioning( + nnx.initializers.ones, + ( + None, + "norm", + ), + ), param_dtype=weights_dtype, ) @@ -737,7 +763,13 @@ def 
__init__( num_features=self.inner_dim, rngs=rngs, dtype=dtype, - scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)), + scale_init=nnx.with_partitioning( + nnx.initializers.ones, + ( + None, + "norm", + ), + ), param_dtype=weights_dtype, ) diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py index 409b425b0..0edab5070 100644 --- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py +++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py @@ -25,6 +25,7 @@ from chex import Array from ..utils import logging from .. import max_logging +from .. import common_types logger = logging.get_logger(__name__) @@ -86,7 +87,7 @@ def rename_key(key): # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py -def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): +def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict, model_type=None): """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" # conv norm or layer norm renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) @@ -109,9 +110,17 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name) if renamed_pt_tuple_key in random_flax_state_dict: if isinstance(random_flax_state_dict[renamed_pt_tuple_key], Partitioned): - assert random_flax_state_dict[renamed_pt_tuple_key].value.shape == pt_tensor.T.shape + # Wan 2.1 uses nnx.scan and nnx.vmap which stacks layer weights which will cause a shape mismatch + # from the original weights which are not stacked. 
+ if model_type is not None and model_type == common_types.WAN_MODEL: + pass + else: + assert random_flax_state_dict[renamed_pt_tuple_key].value.shape == pt_tensor.T.shape else: - assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape + if model_type is not None and model_type == common_types.WAN_MODEL: + pass + else: + assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape return renamed_pt_tuple_key, pt_tensor.T if ( diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index e0db9dd16..b1ae70b7a 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -359,6 +359,7 @@ def __init__( ): inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels + self.num_layers = num_layers # 1. Patch & position embedding self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) @@ -396,9 +397,10 @@ def __init__( ) # 3. 
Transformer blocks - blocks = [] - for _ in range(num_layers): - block = WanTransformerBlock( + @nnx.split_rngs(splits=num_layers) + @nnx.vmap(in_axes=0, out_axes=0) + def init_block(rngs): + return WanTransformerBlock( rngs=rngs, dim=inner_dim, ffn_dim=ffn_dim, @@ -414,8 +416,8 @@ def __init__( precision=precision, attention=attention, ) - blocks.append(block) - self.blocks = blocks + + self.blocks = init_block(rngs) self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False) self.proj_out = nnx.Linear( @@ -438,8 +440,6 @@ def __call__( hidden_states: jax.Array, timestep: jax.Array, encoder_hidden_states: jax.Array, - is_uncond: jax.Array, # jnp.bool_ scalar - slg_mask: jax.Array, # jnp.bool_ array of shape (num_blocks,) encoder_hidden_states_image: Optional[jax.Array] = None, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -463,21 +463,21 @@ def __call__( if encoder_hidden_states_image is not None: raise NotImplementedError("img2vid is not yet implemented.") - def skip_block_true(hidden_states): - split_bs = hidden_states.shape[0] // 2 - prev_neg_hidden_states = hidden_states[split_bs:] + def scan_fn(carry, block): + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb = carry hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) - hidden_states = jnp.concatenate([hidden_states[:split_bs], prev_neg_hidden_states], axis=0) - return hidden_states - - for block_idx, block in enumerate(self.blocks): - should_skip_block = slg_mask[block_idx] & is_uncond - hidden_states = jax.lax.cond( - should_skip_block, - lambda _: skip_block_true(hidden_states), # If true, pass through original hidden_states (skip block) - lambda _: block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb), - hidden_states, - ) + return (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + + initial_carry = (hidden_states, encoder_hidden_states, timestep_proj, 
rotary_emb) + final_carry = nnx.scan( + scan_fn, + length=self.num_layers, + in_axes=(nnx.Carry, 0), + out_axes=nnx.Carry, + )(initial_carry, self.blocks) + + hidden_states = final_carry[0] + shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).astype(hidden_states.dtype) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 6623e78df..2ceb0f7e6 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -8,6 +8,7 @@ from safetensors import safe_open from flax.traverse_util import unflatten_dict, flatten_dict from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict) +from ...common_types import WAN_MODEL CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid" WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX" @@ -82,9 +83,22 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di pt_tuple_key = tuple(renamed_pt_key.split(".")) - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) + if "blocks" in pt_tuple_key: + new_key = ("blocks",) + pt_tuple_key[2:] + block_index = int(pt_tuple_key[1]) + pt_tuple_key = new_key + flax_key, flax_tensor = rename_key_and_reshape_tensor( + pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL + ) flax_key = rename_for_nnx(flax_key) flax_key = _tuple_str_to_int(flax_key) + + if "blocks" in flax_key: + if flax_key in flax_state_dict: + new_tensor = flax_state_dict[flax_key] + else: + new_tensor = jnp.zeros((40,) + flax_tensor.shape) + flax_tensor = new_tensor.at[block_index].set(flax_tensor) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) flax_state_dict = 
unflatten_dict(flax_state_dict) @@ -117,9 +131,22 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di pt_tuple_key = tuple(renamed_pt_key.split(".")) - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) + if "blocks" in pt_tuple_key: + new_key = ("blocks",) + pt_tuple_key[2:] + block_index = int(pt_tuple_key[1]) + pt_tuple_key = new_key + flax_key, flax_tensor = rename_key_and_reshape_tensor( + pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL + ) flax_key = rename_for_nnx(flax_key) flax_key = _tuple_str_to_int(flax_key) + + if "blocks" in flax_key: + if flax_key in flax_state_dict: + new_tensor = flax_state_dict[flax_key] + else: + new_tensor = jnp.zeros((40,) + flax_tensor.shape) + flax_tensor = new_tensor.at[block_index].set(flax_tensor) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) flax_state_dict = unflatten_dict(flax_state_dict) @@ -196,9 +223,22 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm") pt_tuple_key = tuple(renamed_pt_key.split(".")) - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) + if "blocks" in pt_tuple_key: + new_key = ("blocks",) + pt_tuple_key[2:] + block_index = int(pt_tuple_key[1]) + pt_tuple_key = new_key + flax_key, flax_tensor = rename_key_and_reshape_tensor( + pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL + ) flax_key = rename_for_nnx(flax_key) flax_key = _tuple_str_to_int(flax_key) + + if "blocks" in flax_key: + if flax_key in flax_state_dict: + new_tensor = flax_state_dict[flax_key] + else: + new_tensor = jnp.zeros((40,) + flax_tensor.shape) + flax_tensor = new_tensor.at[block_index].set(flax_tensor) flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), 
device=cpu) validate_flax_state_dict(eval_shapes, flax_state_dict) flax_state_dict = unflatten_dict(flax_state_dict) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index ed5b84489..abf449291 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -376,9 +376,6 @@ def __call__( prompt_embeds: jax.Array = None, negative_prompt_embeds: jax.Array = None, vae_only: bool = False, - slg_layers: List[int] = None, - slg_start: float = 0.0, - slg_end: float = 1.0, ): if not vae_only: if num_frames % self.vae_scale_factor_temporal != 1: @@ -434,9 +431,6 @@ def __call__( num_inference_steps=num_inference_steps, scheduler=self.scheduler, scheduler_state=scheduler_state, - slg_layers=slg_layers, - slg_start=slg_start, - slg_end=slg_end, num_transformer_layers=self.transformer.config.num_layers, ) @@ -471,15 +465,11 @@ def transformer_forward_pass( latents, timestep, prompt_embeds, - is_uncond, - slg_mask, do_classifier_free_guidance, guidance_scale, ): wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) - noise_pred = wan_transformer( - hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, is_uncond=is_uncond, slg_mask=slg_mask - ) + noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) if do_classifier_free_guidance: bsz = latents.shape[0] // 2 noise_uncond = noise_pred[bsz:] @@ -502,17 +492,11 @@ def run_inference( scheduler: FlaxUniPCMultistepScheduler, num_transformer_layers: int, scheduler_state, - slg_layers: List[int] = None, - slg_start: float = 0.0, - slg_end: float = 1.0, ): do_classifier_free_guidance = guidance_scale > 1.0 if do_classifier_free_guidance: prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) for step in range(num_inference_steps): - slg_mask = jnp.zeros(num_transformer_layers, dtype=jnp.bool_) - if slg_layers 
and int(slg_start * num_inference_steps) <= step < int(slg_end * num_inference_steps): - slg_mask = slg_mask.at[jnp.array(slg_layers)].set(True) t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] if do_classifier_free_guidance: latents = jnp.concatenate([latents] * 2) @@ -525,8 +509,6 @@ def run_inference( latents, timestep, prompt_embeds, - is_uncond=jnp.array(True, dtype=jnp.bool_), - slg_mask=slg_mask, do_classifier_free_guidance=do_classifier_free_guidance, guidance_scale=guidance_scale, ) diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 4ea50cc7a..01d169b01 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -269,11 +269,7 @@ def test_wan_model(self): dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096)) with mesh: dummy_output = wan_model( - hidden_states=dummy_hidden_states, - timestep=dummy_timestep, - encoder_hidden_states=dummy_encoder_hidden_states, - is_uncond=jnp.array(True, dtype=jnp.bool_), - slg_mask=jnp.zeros(num_layers, dtype=jnp.bool_), + hidden_states=dummy_hidden_states, timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states ) assert dummy_output.shape == hidden_states_shape diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index 31ce039ad..d568709f0 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -49,11 +49,14 @@ def print_ssim(pretrained_video_path, posttrained_video_path): pretrained_video = video_processor.preprocess_video(pretrained_video) pretrained_video = np.array(pretrained_video) pretrained_video = np.transpose(pretrained_video, (0, 2, 3, 4, 1)) + pretrained_video = np.uint8((pretrained_video + 1) * 255 / 2) posttrained_video = load_video(posttrained_video_path[0]) posttrained_video = video_processor.preprocess_video(posttrained_video) posttrained_video = 
np.array(posttrained_video) posttrained_video = np.transpose(posttrained_video, (0, 2, 3, 4, 1)) + posttrained_video = np.uint8((posttrained_video + 1) * 255 / 2) + ssim_compare = ssim(pretrained_video[0], posttrained_video[0], multichannel=True, channel_axis=-1, data_range=255) max_logging.log(f"SSIM score after training is {ssim_compare}") @@ -162,7 +165,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera state = state.to_pure_dict() p_train_step = jax.jit( - functools.partial(train_step, scheduler=pipeline.scheduler), + functools.partial(train_step, scheduler=pipeline.scheduler, config=self.config), donate_argnums=(0,), ) rng = jax.random.key(self.config.seed) @@ -216,16 +219,16 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera return pipeline -def train_step(state, graphdef, scheduler_state, data, rng, scheduler): - return step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng) +def train_step(state, graphdef, scheduler_state, data, rng, scheduler, config): + return step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config) -def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng): +def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config): _, new_rng, timestep_rng = jax.random.split(rng, num=3) def loss_fn(model): - latents = data["latents"] - encoder_hidden_states = data["encoder_hidden_states"] + latents = data["latents"].astype(config.weights_dtype) + encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype) bsz = latents.shape[0] timesteps = jax.random.randint( timestep_rng, @@ -240,8 +243,6 @@ def loss_fn(model): hidden_states=noisy_latents, timestep=timesteps, encoder_hidden_states=encoder_hidden_states, - is_uncond=jnp.array(False, dtype=jnp.bool_), - slg_mask=jnp.zeros(1, dtype=jnp.bool_), ) training_target = scheduler.training_target(latents, noise, timesteps)