2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
@@ -48,6 +48,8 @@ replicate_vae: False
# fp32 activations and fp32 weights with HIGHEST will provide the best precision
# at the cost of time.
precision: "DEFAULT"
# Use jax.lax.scan for transformer layers
scan_layers: True

# if False state is not jitted and instead replicate is called. This is good for debugging on single host
# It must be True for multi-host.
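For context on what the new flag toggles, here is a minimal sketch in plain JAX; it is not taken from this change, and the layer function and shapes are made up. It contrasts folding one layer function over weights stacked along a leading layer axis (`scan_layers: True`) with an unrolled Python loop over the layers (`scan_layers: False`).

```python
# Minimal sketch (hypothetical layer and shapes): what scanning transformer layers means.
import jax
import jax.numpy as jnp

num_layers, d = 4, 8
# Stacked weights: a single array with a leading (num_layers, ...) axis.
stacked_w = jnp.stack([jnp.eye(d) * (i + 1) for i in range(num_layers)])

def layer(h, w):
  return jnp.tanh(h @ w)

def scanned_forward(h, stacked_w):
  def body(carry, w):                       # carry = hidden states
    return layer(carry, w), None            # no per-step outputs are collected
  h, _ = jax.lax.scan(body, h, stacked_w)   # layer body traced once, run num_layers times
  return h

def unrolled_forward(h, stacked_w):
  for i in range(num_layers):               # each layer is traced separately
    h = layer(h, stacked_w[i])
  return h

h0 = jnp.ones((2, d))
assert jnp.allclose(scanned_forward(h0, stacked_w), unrolled_forward(h0, stacked_w), atol=1e-5)
```

Scanning traces and compiles the layer body once regardless of depth, which generally shortens compile time for deep stacks; the trade-off is that per-layer weights must be stored stacked, which the checkpoint-loading changes later in this diff accommodate.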
1 change: 1 addition & 0 deletions src/maxdiffusion/input_pipeline/_tfds_data_processing.py
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@
AUTOTUNE = tf.data.AUTOTUNE
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count):
dataset = dataset.with_format("tensorflow")[:]
tf_dataset = tf.data.Dataset.from_tensor_slices(dataset)
Original file line number Diff line number Diff line change
@@ -42,6 +42,7 @@
AUTOTUNE = tf.data.experimental.AUTOTUNE
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def make_data_iterator(
config,
dataloading_host_index,
3 changes: 2 additions & 1 deletion src/maxdiffusion/models/gradient_checkpoint.py
Original file line number Diff line number Diff line change
@@ -85,6 +85,7 @@ def apply(
names_which_can_be_saved: list = [],
names_which_can_be_offloaded: list = [],
static_argnums=(),
prevent_cse: bool = False,
) -> nnx.Module:
"""
Applies a gradient checkpoint policy to a module
@@ -99,4 +100,4 @@ def apply(
policy = self.to_jax_policy(names_which_can_be_saved, names_which_can_be_offloaded)
if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
return module
return nnx.remat(module, prevent_cse=False, policy=policy, static_argnums=static_argnums) # pylint: disable=invalid-name
return nnx.remat(module, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums) # pylint: disable=invalid-name
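Background on the new `prevent_cse` argument, illustrated with a generic `jax.checkpoint` snippet rather than code from this repository: rematerialization prevents XLA's common-subexpression elimination by default so the recomputed values are not merged back with the originals, and the JAX docs note that this prevention is unnecessary when the rematted function runs inside `jax.lax.scan`. That is why the scanned path can keep `prevent_cse=False` while the new unrolled path passes `True`.

```python
# Generic illustration (not from this repo) of choosing prevent_cse for jax.checkpoint.
import jax
import jax.numpy as jnp

def layer(h, w):
  return jnp.tanh(h @ w)

# Inside lax.scan, CSE prevention is not needed, so it can be switched off.
scan_body = jax.checkpoint(lambda h, w: (layer(h, w), None), prevent_cse=False)

def scanned(h, stacked_w):
  h, _ = jax.lax.scan(scan_body, h, stacked_w)
  return h

# For per-layer calls in an unrolled loop, keep prevent_cse=True (the default)
# so CSE does not undo the recomputation that rematerialization relies on.
per_layer = jax.checkpoint(layer, prevent_cse=True)
```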
7 changes: 3 additions & 4 deletions src/maxdiffusion/models/modeling_flax_pytorch_utils.py
Original file line number Diff line number Diff line change
@@ -25,7 +25,6 @@
from chex import Array
from ..utils import logging
from .. import max_logging
from .. import common_types


logger = logging.get_logger(__name__)
@@ -87,7 +86,7 @@ def rename_key(key):

# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict, model_type=None):
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict, scan_layers=False):
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
# conv norm or layer norm
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
@@ -112,12 +111,12 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic
if isinstance(random_flax_state_dict[renamed_pt_tuple_key], Partitioned):
# Wan 2.1 uses nnx.scan and nnx.vmap which stacks layer weights which will cause a shape mismatch
# from the original weights which are not stacked.
if model_type is not None and model_type == common_types.WAN_MODEL:
if scan_layers:
pass
else:
assert random_flax_state_dict[renamed_pt_tuple_key].value.shape == pt_tensor.T.shape
else:
if model_type is not None and model_type == common_types.WAN_MODEL:
if scan_layers:
pass
else:
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
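A small shape illustration (hypothetical sizes, not repository code) of why the assertion is bypassed when `scan_layers` is set: the stacked parameter carries an extra leading layer axis, so its shape can never equal the transpose of a single layer's PyTorch tensor.

```python
# Hypothetical shapes: a PyTorch nn.Linear weight is (out_features, in_features),
# the matching Flax kernel is its transpose, and under scan_layers the checkpoint
# entry additionally carries a leading num_layers axis.
import jax.numpy as jnp

num_layers, in_features, out_features = 40, 8, 16
pt_tensor = jnp.zeros((out_features, in_features))                    # one layer, PyTorch layout
flax_kernel = jnp.zeros((in_features, out_features))                  # per-layer Flax kernel
stacked_kernel = jnp.zeros((num_layers, in_features, out_features))   # scanned / stacked kernel

assert flax_kernel.shape == pt_tensor.T.shape      # the check that holds when layers are unrolled
assert stacked_kernel.shape != pt_tensor.T.shape   # why the check is skipped when scan_layers=True
```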
75 changes: 56 additions & 19 deletions src/maxdiffusion/models/wan/transformers/transformer_wan.py
Original file line number Diff line number Diff line change
@@ -396,10 +396,12 @@ def __init__(
remat_policy: str = "None",
names_which_can_be_saved: list = [],
names_which_can_be_offloaded: list = [],
scan_layers: bool = True,
):
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
self.num_layers = num_layers
self.scan_layers = scan_layers

# 1. Patch & position embedding
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -455,8 +457,29 @@ def init_block(rngs):
self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
self.names_which_can_be_offloaded = names_which_can_be_offloaded
self.names_which_can_be_saved = names_which_can_be_saved

self.blocks = init_block(rngs)
if scan_layers:
self.blocks = init_block(rngs)
else:
blocks = nnx.List([])
for _ in range(num_layers):
block = WanTransformerBlock(
rngs=rngs,
dim=inner_dim,
ffn_dim=ffn_dim,
num_heads=num_attention_heads,
qk_norm=qk_norm,
cross_attn_norm=cross_attn_norm,
eps=eps,
flash_min_seq_length=flash_min_seq_length,
flash_block_sizes=flash_block_sizes,
mesh=mesh,
dtype=dtype,
weights_dtype=weights_dtype,
precision=precision,
attention=attention,
)
blocks.append(block)
self.blocks = blocks

self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
self.proj_out = nnx.Linear(
@@ -505,24 +528,38 @@ def __call__(
if encoder_hidden_states_image is not None:
raise NotImplementedError("img2vid is not yet implemented.")

def scan_fn(carry, block):
hidden_states_carry, rngs_carry = carry
hidden_states = block(hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry)
new_carry = (hidden_states, rngs_carry)
return new_carry, None
if self.scan_layers:

rematted_block_forward = self.gradient_checkpoint.apply(
scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded
)
initial_carry = (hidden_states, rngs)
final_carry, _ = nnx.scan(
rematted_block_forward,
length=self.num_layers,
in_axes=(nnx.Carry, 0),
out_axes=(nnx.Carry, 0),
)(initial_carry, self.blocks)

hidden_states, _ = final_carry
def scan_fn(carry, block):
hidden_states_carry, rngs_carry = carry
hidden_states = block(
hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry
)
new_carry = (hidden_states, rngs_carry)
return new_carry, None

rematted_block_forward = self.gradient_checkpoint.apply(
scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
)
initial_carry = (hidden_states, rngs)
final_carry, _ = nnx.scan(
rematted_block_forward,
length=self.num_layers,
in_axes=(nnx.Carry, 0),
out_axes=(nnx.Carry, 0),
)(initial_carry, self.blocks)

hidden_states, _ = final_carry
else:
for block in self.blocks:

def layer_forward(hidden_states):
return block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs)

rematted_layer_forward = self.gradient_checkpoint.apply(
layer_forward, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
)
hidden_states = rematted_layer_forward(hidden_states)

shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)

115 changes: 57 additions & 58 deletions src/maxdiffusion/models/wan/wan_utils.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,6 @@
from safetensors import safe_open
from flax.traverse_util import unflatten_dict, flatten_dict
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
from ...common_types import WAN_MODEL

CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX"
@@ -73,8 +72,35 @@ def rename_for_custom_trasformer(key):
return renamed_pt_key


def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers):
if scan_layers:
if "blocks" in pt_tuple_key:
new_key = ("blocks",) + pt_tuple_key[2:]
block_index = int(pt_tuple_key[1])
pt_tuple_key = new_key

flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)

flax_key = rename_for_nnx(flax_key)
flax_key = _tuple_str_to_int(flax_key)

if scan_layers:
if "blocks" in flax_key:
if flax_key in flax_state_dict:
new_tensor = flax_state_dict[flax_key]
else:
new_tensor = jnp.zeros((40,) + flax_tensor.shape)
flax_tensor = new_tensor.at[block_index].set(flax_tensor)
return flax_key, flax_tensor


def load_fusionx_transformer(
pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
pretrained_model_name_or_path: str,
eval_shapes: dict,
device: str,
hf_download: bool = True,
num_layers: int = 40,
scan_layers: bool = True,
):
device = jax.local_devices(backend=device)[0]
with jax.default_device(device):
@@ -101,23 +127,9 @@

pt_tuple_key = tuple(renamed_pt_key.split("."))

if "blocks" in pt_tuple_key:
new_key = ("blocks",) + pt_tuple_key[2:]
block_index = int(pt_tuple_key[1])
pt_tuple_key = new_key
flax_key, flax_tensor = rename_key_and_reshape_tensor(
pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL
)
flax_key = rename_for_nnx(flax_key)
flax_key = _tuple_str_to_int(flax_key)

if "blocks" in flax_key:
if flax_key in flax_state_dict:
new_tensor = flax_state_dict[flax_key]
else:
new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
flax_tensor = new_tensor.at[block_index].set(flax_tensor)
flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)

validate_flax_state_dict(eval_shapes, flax_state_dict)
flax_state_dict = unflatten_dict(flax_state_dict)
del tensors
@@ -126,7 +138,12 @@


def load_causvid_transformer(
pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
pretrained_model_name_or_path: str,
eval_shapes: dict,
device: str,
hf_download: bool = True,
num_layers: int = 40,
scan_layers: bool = True,
):
device = jax.local_devices(backend=device)[0]
with jax.default_device(device):
@@ -150,24 +167,9 @@
renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key)

pt_tuple_key = tuple(renamed_pt_key.split("."))

if "blocks" in pt_tuple_key:
new_key = ("blocks",) + pt_tuple_key[2:]
block_index = int(pt_tuple_key[1])
pt_tuple_key = new_key
flax_key, flax_tensor = rename_key_and_reshape_tensor(
pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL
)
flax_key = rename_for_nnx(flax_key)
flax_key = _tuple_str_to_int(flax_key)

if "blocks" in flax_key:
if flax_key in flax_state_dict:
new_tensor = flax_state_dict[flax_key]
else:
new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
flax_tensor = new_tensor.at[block_index].set(flax_tensor)
flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)

validate_flax_state_dict(eval_shapes, flax_state_dict)
flax_state_dict = unflatten_dict(flax_state_dict)
del tensors
@@ -176,19 +178,31 @@


def load_wan_transformer(
pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
pretrained_model_name_or_path: str,
eval_shapes: dict,
device: str,
hf_download: bool = True,
num_layers: int = 40,
scan_layers: bool = True,
):

if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
else:
return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
return load_base_wan_transformer(
pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers
)


def load_base_wan_transformer(
pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
pretrained_model_name_or_path: str,
eval_shapes: dict,
device: str,
hf_download: bool = True,
num_layers: int = 40,
scan_layers: bool = True,
):
device = jax.local_devices(backend=device)[0]
subfolder = "transformer"
@@ -247,24 +261,9 @@ def load_base_wan_transformer(
renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
pt_tuple_key = tuple(renamed_pt_key.split("."))

if "blocks" in pt_tuple_key:
new_key = ("blocks",) + pt_tuple_key[2:]
block_index = int(pt_tuple_key[1])
pt_tuple_key = new_key
flax_key, flax_tensor = rename_key_and_reshape_tensor(
pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL
)
flax_key = rename_for_nnx(flax_key)
flax_key = _tuple_str_to_int(flax_key)

if "blocks" in flax_key:
if flax_key in flax_state_dict:
new_tensor = flax_state_dict[flax_key]
else:
new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
flax_tensor = new_tensor.at[block_index].set(flax_tensor)
flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)

validate_flax_state_dict(eval_shapes, flax_state_dict)
flax_state_dict = unflatten_dict(flax_state_dict)
del tensors
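For reference, a tiny sketch (hypothetical shapes, not repository code) of the accumulation pattern the new `get_key_and_value` helper uses when `scan_layers` is set: each per-layer tensor from the checkpoint is written into a single buffer with a leading layer axis via JAX's functional index update, so the loaded weights match the stacked layout of the scanned blocks.

```python
# Tiny sketch (hypothetical shapes) of stacking per-layer tensors for scanned blocks.
import jax.numpy as jnp

num_layers, shape = 4, (3, 5)
stacked = None
for block_index in range(num_layers):
  layer_tensor = jnp.full(shape, block_index, dtype=jnp.float32)  # stand-in for one layer's weight
  if stacked is None:
    stacked = jnp.zeros((num_layers,) + layer_tensor.shape)       # allocate the stacked buffer once
  stacked = stacked.at[block_index].set(layer_tensor)             # functional in-place update

assert stacked.shape == (num_layers,) + shape
```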