
Commit 08834ab

Adds an option to use jax.lax.scan for transformer layers; linting fixes.

1 parent fb1c00b · commit 08834ab

15 files changed: 217 additions & 155 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 0 deletions

@@ -48,6 +48,8 @@ replicate_vae: False
 # fp32 activations and fp32 weights with HIGHEST will provide the best precision
 # at the cost of time.
 precision: "DEFAULT"
+# Use jax.lax.scan for transformer layers
+scan_layers: True
 
 # if False state is not jitted and instead replicate is called. This is good for debugging on single host
 # It must be True for multi-host.
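Background for the new flag: with scan_layers enabled, the transformer's blocks are stacked along a leading axis and iterated with jax.lax.scan, so one block is traced and compiled once rather than once per layer, keeping compile time flat as depth grows. A minimal standalone sketch of the idea (toy shapes and names, not this repo's code):

import jax
import jax.numpy as jnp

num_layers, dim = 4, 8
# Stacked per-layer weights: the leading axis indexes the layer.
stacked_w = jnp.stack([jnp.eye(dim) * (i + 1) for i in range(num_layers)])

def block(h, w):
  # Stand-in for one transformer block; scan threads h through every layer.
  return jnp.tanh(h @ w), None

h0 = jnp.ones((2, dim))
h_final, _ = jax.lax.scan(block, h0, stacked_w)  # traced once, iterated num_layers times
print(h_final.shape)  # (2, 8)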

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@
 AUTOTUNE = tf.data.AUTOTUNE
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+
 def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count):
   dataset = dataset.with_format("tensorflow")[:]
   tf_dataset = tf.data.Dataset.from_tensor_slices(dataset)

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@
 AUTOTUNE = tf.data.experimental.AUTOTUNE
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+
 def make_data_iterator(
     config,
     dataloading_host_index,

src/maxdiffusion/models/gradient_checkpoint.py

Lines changed: 2 additions & 1 deletion

@@ -85,6 +85,7 @@ def apply(
       names_which_can_be_saved: list = [],
       names_which_can_be_offloaded: list = [],
       static_argnums=(),
+      prevent_cse: bool = False,
   ) -> nnx.Module:
     """
     Applies a gradient checkpoint policy to a module

@@ -99,4 +100,4 @@ def apply(
     policy = self.to_jax_policy(names_which_can_be_saved, names_which_can_be_offloaded)
     if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
       return module
-    return nnx.remat(module, prevent_cse=False, policy=policy, static_argnums=static_argnums)  # pylint: disable=invalid-name
+    return nnx.remat(module, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums)  # pylint: disable=invalid-name
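Background on the new prevent_cse argument: jax.checkpoint's prevent_cse=True inserts barriers so XLA's common-subexpression elimination cannot undo rematerialization under jit; inside lax.scan those barriers are unnecessary and only add overhead, which is why the transformer below passes prevent_cse=not scan_layers. A hedged sketch with plain jax.checkpoint (nnx.remat wraps the same machinery; shapes and names here are illustrative):

import jax
import jax.numpy as jnp

def block(h):
  return jnp.tanh(h @ jnp.ones((8, 8)))

# Scanned path: CSE barriers are not needed inside lax.scan.
scan_body = jax.checkpoint(lambda h, _: (block(h), None), prevent_cse=False)

# Unrolled path: keep prevent_cse=True so XLA cannot optimize away
# the recomputation that checkpointing relies on.
unrolled_block = jax.checkpoint(block, prevent_cse=True)

h = jnp.ones((2, 8))
h, _ = jax.lax.scan(scan_body, h, None, length=4)
for _ in range(4):
  h = unrolled_block(h)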

src/maxdiffusion/models/modeling_flax_pytorch_utils.py

Lines changed: 3 additions & 4 deletions

@@ -25,7 +25,6 @@
 from chex import Array
 from ..utils import logging
 from .. import max_logging
-from .. import common_types
 
 
 logger = logging.get_logger(__name__)

@@ -87,7 +86,7 @@ def rename_key(key):
 
 # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
-def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict, model_type=None):
+def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict, scan_layers=False):
   """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
   # conv norm or layer norm
   renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)

@@ -112,12 +111,12 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict, model_type=None):
   if isinstance(random_flax_state_dict[renamed_pt_tuple_key], Partitioned):
     # Wan 2.1 uses nnx.scan and nnx.vmap which stacks layer weights which will cause a shape mismatch
     # from the original weights which are not stacked.
-    if model_type is not None and model_type == common_types.WAN_MODEL:
+    if scan_layers:
       pass
     else:
       assert random_flax_state_dict[renamed_pt_tuple_key].value.shape == pt_tensor.T.shape
   else:
-    if model_type is not None and model_type == common_types.WAN_MODEL:
+    if scan_layers:
       pass
     else:
       assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
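Why the assertion must be skipped when scan_layers is set: an nnx.vmap-initialized model stores each block parameter as one array of shape (num_layers, ...), while the PyTorch checkpoint stores every layer's tensor separately, so the per-tensor shape check cannot hold. A small illustration with hypothetical shapes:

import jax.numpy as jnp

num_layers, dim = 40, 5120
pt_tensor = jnp.zeros((dim, dim))            # one layer's weight in the checkpoint
stacked = jnp.zeros((num_layers, dim, dim))  # scan_layers=True: all layers in one array

assert stacked[0].shape == pt_tensor.T.shape  # a per-layer slice still matches
assert stacked.shape != pt_tensor.T.shape     # the stacked parameter itself does not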

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 56 additions & 19 deletions

@@ -396,10 +396,12 @@ def __init__(
       remat_policy: str = "None",
       names_which_can_be_saved: list = [],
       names_which_can_be_offloaded: list = [],
+      scan_layers: bool = True,
   ):
     inner_dim = num_attention_heads * attention_head_dim
     out_channels = out_channels or in_channels
     self.num_layers = num_layers
+    self.scan_layers = scan_layers
 
     # 1. Patch & position embedding
     self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)

@@ -455,8 +457,29 @@ def init_block(rngs):
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
     self.names_which_can_be_offloaded = names_which_can_be_offloaded
     self.names_which_can_be_saved = names_which_can_be_saved
-
-    self.blocks = init_block(rngs)
+    if scan_layers:
+      self.blocks = init_block(rngs)
+    else:
+      blocks = nnx.List([])
+      for _ in range(num_layers):
+        block = WanTransformerBlock(
+            rngs=rngs,
+            dim=inner_dim,
+            ffn_dim=ffn_dim,
+            num_heads=num_attention_heads,
+            qk_norm=qk_norm,
+            cross_attn_norm=cross_attn_norm,
+            eps=eps,
+            flash_min_seq_length=flash_min_seq_length,
+            flash_block_sizes=flash_block_sizes,
+            mesh=mesh,
+            dtype=dtype,
+            weights_dtype=weights_dtype,
+            precision=precision,
+            attention=attention,
+        )
+        blocks.append(block)
+      self.blocks = blocks
 
     self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
     self.proj_out = nnx.Linear(

@@ -505,24 +528,38 @@ def __call__(
     if encoder_hidden_states_image is not None:
       raise NotImplementedError("img2vid is not yet implemented.")
 
-    def scan_fn(carry, block):
-      hidden_states_carry, rngs_carry = carry
-      hidden_states = block(hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry)
-      new_carry = (hidden_states, rngs_carry)
-      return new_carry, None
+    if self.scan_layers:
 
-    rematted_block_forward = self.gradient_checkpoint.apply(
-        scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded
-    )
-    initial_carry = (hidden_states, rngs)
-    final_carry, _ = nnx.scan(
-        rematted_block_forward,
-        length=self.num_layers,
-        in_axes=(nnx.Carry, 0),
-        out_axes=(nnx.Carry, 0),
-    )(initial_carry, self.blocks)
-
-    hidden_states, _ = final_carry
+      def scan_fn(carry, block):
+        hidden_states_carry, rngs_carry = carry
+        hidden_states = block(
+            hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry
+        )
+        new_carry = (hidden_states, rngs_carry)
+        return new_carry, None
+
+      rematted_block_forward = self.gradient_checkpoint.apply(
+          scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
+      )
+      initial_carry = (hidden_states, rngs)
+      final_carry, _ = nnx.scan(
+          rematted_block_forward,
+          length=self.num_layers,
+          in_axes=(nnx.Carry, 0),
+          out_axes=(nnx.Carry, 0),
+      )(initial_carry, self.blocks)
+
+      hidden_states, _ = final_carry
+    else:
+      for block in self.blocks:
+
+        def layer_forward(hidden_states):
+          return block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs)
+
+        rematted_layer_forward = self.gradient_checkpoint.apply(
+            layer_forward, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
+        )
+        hidden_states = rematted_layer_forward(hidden_states)
 
     shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
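The branch above trades compile time for flexibility: the scan path traces WanTransformerBlock once and runs it self.num_layers times over stacked weights, while the unrolled path traces every block separately but keeps each layer's parameters individually addressable. A minimal sketch of the same two patterns with flax.nnx, using toy Linear blocks in place of WanTransformerBlock (an assumption for brevity):

import jax
import jax.numpy as jnp
from flax import nnx

dim, num_layers = 8, 4

# Scan path: vmap over the constructor stacks each parameter on a leading layer axis.
@nnx.vmap(in_axes=0, out_axes=0)
def create_blocks(key: jax.Array):
  return nnx.Linear(dim, dim, rngs=nnx.Rngs(key))

blocks = create_blocks(jax.random.split(jax.random.key(0), num_layers))

@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0))
def forward(h, block):
  return block(h), None

h, _ = forward(jnp.ones((2, dim)), blocks)

# Unrolled path: independent blocks iterated with a plain Python loop.
unrolled = [nnx.Linear(dim, dim, rngs=nnx.Rngs(i)) for i in range(num_layers)]
for block in unrolled:
  h = block(h)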
528565

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 57 additions & 58 deletions

@@ -24,7 +24,6 @@
 from safetensors import safe_open
 from flax.traverse_util import unflatten_dict, flatten_dict
 from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
-from ...common_types import WAN_MODEL
 
 CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
 WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX"

@@ -73,8 +72,35 @@ def rename_for_custom_trasformer(key):
   return renamed_pt_key
 
 
+def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers):
+  if scan_layers:
+    if "blocks" in pt_tuple_key:
+      new_key = ("blocks",) + pt_tuple_key[2:]
+      block_index = int(pt_tuple_key[1])
+      pt_tuple_key = new_key
+
+  flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
+
+  flax_key = rename_for_nnx(flax_key)
+  flax_key = _tuple_str_to_int(flax_key)
+
+  if scan_layers:
+    if "blocks" in flax_key:
+      if flax_key in flax_state_dict:
+        new_tensor = flax_state_dict[flax_key]
+      else:
+        new_tensor = jnp.zeros((40,) + flax_tensor.shape)
+      flax_tensor = new_tensor.at[block_index].set(flax_tensor)
+  return flax_key, flax_tensor
+
+
 def load_fusionx_transformer(
-    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+    pretrained_model_name_or_path: str,
+    eval_shapes: dict,
+    device: str,
+    hf_download: bool = True,
+    num_layers: int = 40,
+    scan_layers: bool = True,
 ):
   device = jax.local_devices(backend=device)[0]
   with jax.default_device(device):

@@ -101,23 +127,9 @@ def load_fusionx_transformer(
 
       pt_tuple_key = tuple(renamed_pt_key.split("."))
 
-      if "blocks" in pt_tuple_key:
-        new_key = ("blocks",) + pt_tuple_key[2:]
-        block_index = int(pt_tuple_key[1])
-        pt_tuple_key = new_key
-      flax_key, flax_tensor = rename_key_and_reshape_tensor(
-          pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL
-      )
-      flax_key = rename_for_nnx(flax_key)
-      flax_key = _tuple_str_to_int(flax_key)
-
-      if "blocks" in flax_key:
-        if flax_key in flax_state_dict:
-          new_tensor = flax_state_dict[flax_key]
-        else:
-          new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
-        flax_tensor = new_tensor.at[block_index].set(flax_tensor)
+      flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+
     validate_flax_state_dict(eval_shapes, flax_state_dict)
     flax_state_dict = unflatten_dict(flax_state_dict)
     del tensors

@@ -126,7 +138,12 @@
 
 
 def load_causvid_transformer(
-    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+    pretrained_model_name_or_path: str,
+    eval_shapes: dict,
+    device: str,
+    hf_download: bool = True,
+    num_layers: int = 40,
+    scan_layers: bool = True,
 ):
   device = jax.local_devices(backend=device)[0]
   with jax.default_device(device):

@@ -150,24 +167,9 @@
       renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key)
 
       pt_tuple_key = tuple(renamed_pt_key.split("."))
-
-      if "blocks" in pt_tuple_key:
-        new_key = ("blocks",) + pt_tuple_key[2:]
-        block_index = int(pt_tuple_key[1])
-        pt_tuple_key = new_key
-      flax_key, flax_tensor = rename_key_and_reshape_tensor(
-          pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL
-      )
-      flax_key = rename_for_nnx(flax_key)
-      flax_key = _tuple_str_to_int(flax_key)
-
-      if "blocks" in flax_key:
-        if flax_key in flax_state_dict:
-          new_tensor = flax_state_dict[flax_key]
-        else:
-          new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
-        flax_tensor = new_tensor.at[block_index].set(flax_tensor)
+      flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+
     validate_flax_state_dict(eval_shapes, flax_state_dict)
     flax_state_dict = unflatten_dict(flax_state_dict)
     del tensors

@@ -176,19 +178,31 @@
 
 
 def load_wan_transformer(
-    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+    pretrained_model_name_or_path: str,
+    eval_shapes: dict,
+    device: str,
+    hf_download: bool = True,
+    num_layers: int = 40,
+    scan_layers: bool = True,
 ):
 
   if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
-    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
+    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
   elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
-    return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
+    return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
   else:
-    return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
+    return load_base_wan_transformer(
+        pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers
+    )
 
 
 def load_base_wan_transformer(
-    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+    pretrained_model_name_or_path: str,
+    eval_shapes: dict,
+    device: str,
+    hf_download: bool = True,
+    num_layers: int = 40,
+    scan_layers: bool = True,
 ):
   device = jax.local_devices(backend=device)[0]
   subfolder = "transformer"

@@ -247,24 +261,9 @@
       renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
       renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
       pt_tuple_key = tuple(renamed_pt_key.split("."))
-
-      if "blocks" in pt_tuple_key:
-        new_key = ("blocks",) + pt_tuple_key[2:]
-        block_index = int(pt_tuple_key[1])
-        pt_tuple_key = new_key
-      flax_key, flax_tensor = rename_key_and_reshape_tensor(
-          pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL
-      )
-      flax_key = rename_for_nnx(flax_key)
-      flax_key = _tuple_str_to_int(flax_key)
-
-      if "blocks" in flax_key:
-        if flax_key in flax_state_dict:
-          new_tensor = flax_state_dict[flax_key]
-        else:
-          new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
-        flax_tensor = new_tensor.at[block_index].set(flax_tensor)
+      flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+
     validate_flax_state_dict(eval_shapes, flax_state_dict)
     flax_state_dict = unflatten_dict(flax_state_dict)
     del tensors
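The stacking step inside the new helper, in isolation: each checkpoint tensor named blocks.<i>.<rest> is written into row i of a zero-initialized (num_layers, ...) buffer stored under the layer-free key. A simplified sketch that takes the layer count as a parameter (the committed helper fixes it at 40):

import jax.numpy as jnp

def stack_block_weight(flax_state_dict, key, tensor, block_index, num_layers):
  # Accumulate per-layer tensors into one stacked array under a shared key.
  if key in flax_state_dict:
    stacked = flax_state_dict[key]
  else:
    stacked = jnp.zeros((num_layers,) + tensor.shape)
  return stacked.at[block_index].set(tensor)

# e.g. checkpoint entry "blocks.3.ffn.weight" maps to key ("blocks", "ffn", "weight"), row 3
state = {}
key = ("blocks", "ffn", "weight")
state[key] = stack_block_weight(state, key, jnp.ones((8, 8)), 3, num_layers=40)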
