Skip to content

Commit 6e8d610

Browse files
committed
Fix mesh axis bugs & Fix Wan-VACE transformer and pipeline to align with other Wan models and pipeline
* Removes an unnecessary `vae_spatial_axis_name` assignment to `vae_mesh` in `wan_pipeline.py`, which fixes `AttributeError: cannot assign to field 'vae_spatial_axis_name'`. * Adopts the shared `_create_common_components` method (used by `WanPipeline2_1`, `WanPipeline2_2`, etc.) into `_load_and_init` of `VaceWanPipeline2_1`. This aligns initialization and resolves mesh axis handling issues during `VaceWanPipeline2_1` initialization introduced by PR #359. * Adopts new parameters `mask_padding_tokens`, `enable_jax_named_scopes`, `use_base2_exp`, `use_experimental_scheduler` into `transformer_wan_vace.py` and `wan_vace_pipeline_2_1.py` (following `transformer_wan.py` and `wan_pipeline.py` respectively), which resolves attention errors (introduced by PR #359) and aligns the Wan-VACE transformer with the Wan transformer. * Removes the `load_common_components` argument from `from_pretrained` and `from_checkpoint` in `VaceWanPipeline2_1` for interface consistency.
1 parent c98002f commit 6e8d610

3 files changed

Lines changed: 161 additions & 124 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py

Lines changed: 132 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
limitations under the License.
1414
"""
1515

16+
import contextlib
1617
import math
1718
from typing import Any, Dict, Optional, Tuple
1819

@@ -62,8 +63,12 @@ def __init__(
6263
precision: jax.lax.Precision | None = None,
6364
attention: str = "dot_product",
6465
dropout: float = 0.0,
66+
mask_padding_tokens: bool = True,
67+
enable_jax_named_scopes: bool = False,
6568
apply_input_projection: bool = False,
6669
apply_output_projection: bool = False,
70+
use_base2_exp: bool = False,
71+
use_experimental_scheduler: bool = False,
6772
):
6873
"""Sets up the model.
6974
@@ -90,7 +95,7 @@ def __init__(
9095
apply_output_projection: Whether to apply an output projection before
9196
outputting the result.
9297
"""
93-
98+
self.enable_jax_named_scopes = enable_jax_named_scopes
9499
self.apply_input_projection = apply_input_projection
95100
self.apply_output_projection = apply_output_projection
96101

@@ -124,7 +129,12 @@ def __init__(
124129
precision=precision,
125130
attention_kernel=attention,
126131
dropout=dropout,
132+
is_self_attention=True,
133+
mask_padding_tokens=mask_padding_tokens,
127134
residual_checkpoint_name="self_attn",
135+
enable_jax_named_scopes=enable_jax_named_scopes,
136+
use_base2_exp=use_base2_exp,
137+
use_experimental_scheduler=use_experimental_scheduler,
128138
)
129139

130140
# 3. Cross-attention
@@ -143,7 +153,12 @@ def __init__(
143153
precision=precision,
144154
attention_kernel=attention,
145155
dropout=dropout,
156+
is_self_attention=False,
157+
mask_padding_tokens=mask_padding_tokens,
146158
residual_checkpoint_name="cross_attn",
159+
enable_jax_named_scopes=enable_jax_named_scopes,
160+
use_base2_exp=use_base2_exp,
161+
use_experimental_scheduler=use_experimental_scheduler,
147162
)
148163
assert cross_attn_norm is True, "cross_attn_norm must be True"
149164
self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -158,6 +173,7 @@ def __init__(
158173
weights_dtype=weights_dtype,
159174
precision=precision,
160175
dropout=dropout,
176+
enable_jax_named_scopes=enable_jax_named_scopes,
161177
)
162178

163179
self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
@@ -180,6 +196,10 @@ def __init__(
180196
jax.random.normal(key, (1, 6, dim)) / dim**0.5,
181197
)
182198

199+
def conditional_named_scope(self, name: str):
200+
"""Return a JAX named scope if enabled, otherwise a null context."""
201+
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
202+
183203
def __call__(
184204
self,
185205
*,
@@ -191,65 +211,76 @@ def __call__(
191211
deterministic: bool = True,
192212
rngs: nnx.Rngs | None = None,
193213
) -> Tuple[jax.Array, jax.Array]:
194-
if self.apply_input_projection:
195-
control_hidden_states = self.proj_in(control_hidden_states)
196-
control_hidden_states = control_hidden_states + hidden_states
197214

198-
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
199-
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
200-
)
201-
202-
control_hidden_states = jax.lax.with_sharding_constraint(
203-
control_hidden_states,
204-
PartitionSpec("data", "fsdp", "tensor"),
205-
)
206-
control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
207-
encoder_hidden_states = jax.lax.with_sharding_constraint(
208-
encoder_hidden_states,
209-
PartitionSpec("data", "fsdp", None),
210-
)
211-
212-
# 1. Self-attention
213-
with jax.named_scope("attn1"):
214-
norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
215-
control_hidden_states.dtype
216-
)
217-
attn_output = self.attn1(
218-
hidden_states=norm_hidden_states,
219-
encoder_hidden_states=norm_hidden_states,
220-
rotary_emb=rotary_emb,
221-
deterministic=deterministic,
222-
rngs=rngs,
215+
with self.conditional_named_scope("vace_transformer_block"):
216+
with self.conditional_named_scope("input_projection"):
217+
if self.apply_input_projection:
218+
control_hidden_states = self.proj_in(control_hidden_states)
219+
control_hidden_states = control_hidden_states + hidden_states
220+
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
221+
(self.adaln_scale_shift_table + temb.astype(jnp.float32)),
222+
6,
223+
axis=1,
223224
)
224-
control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
225-
control_hidden_states.dtype
226-
)
227-
228-
# 2. Cross-attention
229-
with jax.named_scope("attn2"):
230-
norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
231-
attn_output = self.attn2(
232-
hidden_states=norm_hidden_states,
233-
encoder_hidden_states=encoder_hidden_states,
234-
deterministic=deterministic,
235-
rngs=rngs,
236-
)
237-
control_hidden_states = control_hidden_states + attn_output
238225

239-
# 3. Feed-forward
240-
with jax.named_scope("ffn"):
241-
norm_hidden_states = (self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
242-
control_hidden_states.dtype
243-
)
244-
ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
245-
control_hidden_states = (
246-
control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
247-
).astype(control_hidden_states.dtype)
248-
conditioning_states = None
249-
if self.apply_output_projection:
250-
conditioning_states = self.proj_out(control_hidden_states)
226+
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
227+
control_hidden_states = jax.lax.with_sharding_constraint(control_hidden_states, axis_names)
228+
control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
229+
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv"))
230+
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
231+
232+
# 1. Self-attention
233+
with self.conditional_named_scope("self_attn"):
234+
with self.conditional_named_scope("self_attn_norm"):
235+
norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
236+
control_hidden_states.dtype
237+
)
238+
with self.conditional_named_scope("self_attn_attn"):
239+
attn_output = self.attn1(
240+
hidden_states=norm_hidden_states,
241+
encoder_hidden_states=norm_hidden_states,
242+
rotary_emb=rotary_emb,
243+
deterministic=deterministic,
244+
rngs=rngs,
245+
)
246+
with self.conditional_named_scope("self_attn_residual"):
247+
control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
248+
control_hidden_states.dtype
249+
)
251250

252-
return conditioning_states, control_hidden_states
251+
# 2. Cross-attention
252+
with self.conditional_named_scope("cross_attn"):
253+
with self.conditional_named_scope("cross_attn_norm"):
254+
norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
255+
with self.conditional_named_scope("cross_attn_attn"):
256+
attn_output = self.attn2(
257+
hidden_states=norm_hidden_states,
258+
encoder_hidden_states=encoder_hidden_states,
259+
deterministic=deterministic,
260+
rngs=rngs,
261+
)
262+
with self.conditional_named_scope("cross_attn_residual"):
263+
control_hidden_states = control_hidden_states + attn_output
264+
265+
# 3. Feed-forward
266+
with self.conditional_named_scope("mlp"):
267+
with self.conditional_named_scope("mlp_norm"):
268+
norm_hidden_states = (
269+
self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa
270+
).astype(control_hidden_states.dtype)
271+
with self.conditional_named_scope("mlp_ffn"):
272+
ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
273+
with self.conditional_named_scope("mlp_residual"):
274+
control_hidden_states = (
275+
control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
276+
).astype(control_hidden_states.dtype)
277+
278+
with self.conditional_named_scope("output_projection"):
279+
conditioning_states = None
280+
if self.apply_output_projection:
281+
conditioning_states = self.proj_out(control_hidden_states)
282+
283+
return conditioning_states, control_hidden_states
253284

254285

255286
class WanVACEModel(WanModel):
@@ -289,7 +320,11 @@ def __init__(
289320
remat_policy: str = "None",
290321
names_which_can_be_saved: list[str] = [],
291322
names_which_can_be_offloaded: list[str] = [],
323+
mask_padding_tokens: bool = True,
292324
scan_layers: bool = True,
325+
enable_jax_named_scopes: bool = False,
326+
use_base2_exp: bool = False,
327+
use_experimental_scheduler: bool = False,
293328
):
294329
"""Initializes the VACE model.
295330
@@ -302,6 +337,7 @@ def __init__(
302337
out_channels = out_channels or in_channels
303338
self.num_layers = num_layers
304339
self.scan_layers = scan_layers
340+
self.enable_jax_named_scopes = enable_jax_named_scopes
305341

306342
# 1. Patch & position embedding
307343
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -329,6 +365,7 @@ def __init__(
329365
text_embed_dim=text_dim,
330366
image_embed_dim=image_dim,
331367
pos_embed_seq_len=pos_embed_seq_len,
368+
flash_min_seq_length=flash_min_seq_length,
332369
)
333370

334371
self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -358,6 +395,10 @@ def __init__(
358395
precision=precision,
359396
attention=attention,
360397
dropout=dropout,
398+
mask_padding_tokens=mask_padding_tokens,
399+
enable_jax_named_scopes=enable_jax_named_scopes,
400+
use_base2_exp=use_base2_exp,
401+
use_experimental_scheduler=use_experimental_scheduler,
361402
)
362403
blocks.append(block)
363404
self.blocks = blocks
@@ -384,8 +425,12 @@ def __init__(
384425
precision=precision,
385426
attention=attention,
386427
dropout=dropout,
428+
mask_padding_tokens=mask_padding_tokens,
429+
enable_jax_named_scopes=enable_jax_named_scopes,
387430
apply_input_projection=vace_block_id == 0,
388431
apply_output_projection=True,
432+
use_base2_exp=use_base2_exp,
433+
use_experimental_scheduler=use_experimental_scheduler,
389434
)
390435
vace_blocks.append(vace_block)
391436
self.vace_blocks = vace_blocks
@@ -421,6 +466,10 @@ def __init__(
421466
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
422467
)
423468

469+
def conditional_named_scope(self, name: str):
470+
"""Return a JAX named scope if enabled, otherwise a null context."""
471+
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
472+
424473
@jax.named_scope("WanVACEModel")
425474
def __call__(
426475
self,
@@ -436,7 +485,7 @@ def __call__(
436485
rngs: nnx.Rngs = None,
437486
) -> jax.Array:
438487
hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
439-
batch_size, num_channels, num_frames, height, width = hidden_states.shape
488+
batch_size, _, num_frames, height, width = hidden_states.shape
440489
p_t, p_h, p_w = self.config.patch_size
441490
post_patch_num_frames = num_frames // p_t
442491
post_patch_height = height // p_h
@@ -453,32 +502,36 @@ def __call__(
453502

454503
hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
455504
control_hidden_states = jnp.transpose(control_hidden_states, (0, 2, 3, 4, 1))
456-
rotary_emb = self.rope(hidden_states)
457-
458-
hidden_states = self.patch_embedding(hidden_states)
459-
hidden_states = jax.lax.collapse(hidden_states, 1, -1)
460-
461-
control_hidden_states = self.vace_patch_embedding(control_hidden_states)
462-
control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
463-
control_hidden_states_padding = jnp.zeros((
464-
batch_size,
465-
control_hidden_states.shape[1],
466-
hidden_states.shape[2] - control_hidden_states.shape[2],
467-
))
505+
with self.conditional_named_scope("rotary_embedding"):
506+
rotary_emb = self.rope(hidden_states)
507+
with self.conditional_named_scope("patch_embedding"):
508+
hidden_states = self.patch_embedding(hidden_states)
509+
hidden_states = jax.lax.collapse(hidden_states, 1, -1)
510+
511+
control_hidden_states = self.vace_patch_embedding(control_hidden_states)
512+
control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
513+
control_hidden_states_padding = jnp.zeros(
514+
(
515+
batch_size,
516+
control_hidden_states.shape[1],
517+
hidden_states.shape[2] - control_hidden_states.shape[2],
518+
)
519+
)
468520

469521
control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2)
470522

471523
# Condition embedder is a FC layer.
472-
(
473-
temb,
474-
timestep_proj,
475-
encoder_hidden_states,
476-
encoder_hidden_states_image,
477-
_,
478-
) = self.condition_embedder( # We will need to mask out the text embedding.
479-
timestep, encoder_hidden_states, encoder_hidden_states_image
480-
)
481-
timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
524+
with self.conditional_named_scope("condition_embedder"):
525+
(
526+
temb,
527+
timestep_proj,
528+
encoder_hidden_states,
529+
encoder_hidden_states_image,
530+
_,
531+
) = self.condition_embedder( # We will need to mask out the text embedding.
532+
timestep, encoder_hidden_states, encoder_hidden_states_image
533+
)
534+
timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
482535

483536
if encoder_hidden_states_image is not None:
484537
raise NotImplementedError("img2vid is not yet implemented.")

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,6 @@ def _create_common_components(cls, config, vae_only=False, i2v=False):
630630
vae_devices_array = flat_devices.reshape(total_devices // vae_spatial, vae_spatial)
631631

632632
vae_mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial"))
633-
vae_mesh.vae_spatial_axis_name = "vae_spatial"
634633
max_logging.log(
635634
f"Created VAE specific mesh with axes ('redundant', 'vae_spatial') to support spatial sharding of {vae_spatial}."
636635
)

0 commit comments

Comments
 (0)