Skip to content

Commit 3d2edcc

Browse files
committed
Use nnx.scan instead of a Python for loop.
1 parent aa442f9 commit 3d2edcc

4 files changed

Lines changed: 49 additions & 25 deletions

File tree

src/maxdiffusion/common_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,5 @@
4343
KEEP_1 = "activation_keep_1"
4444
KEEP_2 = "activation_keep_2"
4545
CONV_OUT = "activation_conv_out_channels"
46+
47+
WAN_MODEL = "Wan2.1"

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from chex import Array
2626
from ..utils import logging
2727
from .. import max_logging
28+
from .. import common_types
2829

2930

3031
logger = logging.get_logger(__name__)
@@ -86,7 +87,7 @@ def rename_key(key):
8687

8788
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
8889
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
89-
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
90+
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict, model_type=None):
9091
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
9192
# conv norm or layer norm
9293
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
@@ -109,9 +110,17 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic
109110
renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
110111
if renamed_pt_tuple_key in random_flax_state_dict:
111112
if isinstance(random_flax_state_dict[renamed_pt_tuple_key], Partitioned):
112-
assert random_flax_state_dict[renamed_pt_tuple_key].value.shape == pt_tensor.T.shape
113+
# Wan 2.1 uses nnx.scan and nnx.vmap, which stack layer weights and therefore cause a shape mismatch
114+
# with the original weights, which are not stacked.
115+
if model_type is not None and model_type == common_types.WAN_MODEL:
116+
pass
117+
else:
118+
assert random_flax_state_dict[renamed_pt_tuple_key].value.shape == pt_tensor.T.shape
113119
else:
114-
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
120+
if model_type is not None and model_type == common_types.WAN_MODEL:
121+
pass
122+
else:
123+
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
115124
return renamed_pt_tuple_key, pt_tensor.T
116125

117126
if (

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ def __init__(
359359
):
360360
inner_dim = num_attention_heads * attention_head_dim
361361
out_channels = out_channels or in_channels
362+
self.num_layers = num_layers
362363

363364
# 1. Patch & position embedding
364365
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -396,9 +397,10 @@ def __init__(
396397
)
397398

398399
# 3. Transformer blocks
399-
blocks = []
400-
for _ in range(num_layers):
401-
block = WanTransformerBlock(
400+
@nnx.split_rngs(splits=num_layers)
401+
@nnx.vmap
402+
def init_block(rngs):
403+
return WanTransformerBlock(
402404
rngs=rngs,
403405
dim=inner_dim,
404406
ffn_dim=ffn_dim,
@@ -414,8 +416,7 @@ def __init__(
414416
precision=precision,
415417
attention=attention,
416418
)
417-
blocks.append(block)
418-
self.blocks = blocks
419+
self.blocks = init_block(rngs)
419420

420421
self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
421422
self.proj_out = nnx.Linear(
@@ -463,21 +464,21 @@ def __call__(
463464
if encoder_hidden_states_image is not None:
464465
raise NotImplementedError("img2vid is not yet implemented.")
465466

466-
def skip_block_true(hidden_states):
467-
split_bs = hidden_states.shape[0] // 2
468-
prev_neg_hidden_states = hidden_states[split_bs:]
467+
def scan_fn(carry, block):
468+
hidden_states, encoder_hidden_states, timestep_proj, rotary_emb = carry
469469
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
470-
hidden_states = jnp.concatenate([hidden_states[:split_bs], prev_neg_hidden_states], axis=0)
471-
return hidden_states
472-
473-
for block_idx, block in enumerate(self.blocks):
474-
should_skip_block = slg_mask[block_idx] & is_uncond
475-
hidden_states = jax.lax.cond(
476-
should_skip_block,
477-
lambda _: skip_block_true(hidden_states), # If true, pass through original hidden_states (skip block)
478-
lambda _: block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb),
479-
hidden_states,
480-
)
470+
return (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
471+
472+
initial_carry = (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
473+
final_carry = nnx.scan(
474+
scan_fn,
475+
length=self.num_layers,
476+
in_axes=(nnx.Carry, 0),
477+
out_axes=nnx.Carry,
478+
)(initial_carry, self.blocks)
479+
480+
hidden_states = final_carry[0]
481+
481482
shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
482483

483484
hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).astype(hidden_states.dtype)

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from safetensors import safe_open
99
from flax.traverse_util import unflatten_dict, flatten_dict
1010
from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
11+
from ...common_types import WAN_MODEL
1112

1213
CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
1314
WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX"
@@ -82,7 +83,7 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
8283

8384
pt_tuple_key = tuple(renamed_pt_key.split("."))
8485

85-
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
86+
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL)
8687
flax_key = rename_for_nnx(flax_key)
8788
flax_key = _tuple_str_to_int(flax_key)
8889
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
@@ -117,7 +118,7 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
117118

118119
pt_tuple_key = tuple(renamed_pt_key.split("."))
119120

120-
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
121+
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL)
121122
flax_key = rename_for_nnx(flax_key)
122123
flax_key = _tuple_str_to_int(flax_key)
123124
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
@@ -196,9 +197,20 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
196197
renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
197198
pt_tuple_key = tuple(renamed_pt_key.split("."))
198199

199-
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
200+
if "blocks" in pt_tuple_key:
201+
new_key = ("blocks",) + pt_tuple_key[2:]
202+
block_index = int(pt_tuple_key[1])
203+
pt_tuple_key = new_key
204+
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL)
200205
flax_key = rename_for_nnx(flax_key)
201206
flax_key = _tuple_str_to_int(flax_key)
207+
208+
if "blocks" in flax_key:
209+
if flax_key in flax_state_dict:
210+
new_tensor = flax_state_dict[flax_key]
211+
else:
212+
new_tensor = jnp.zeros((40,) + flax_tensor.shape)
213+
flax_tensor = new_tensor.at[block_index].set(flax_tensor)
202214
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
203215
validate_flax_state_dict(eval_shapes, flax_state_dict)
204216
flax_state_dict = unflatten_dict(flax_state_dict)

0 commit comments

Comments (0)