Skip to content

Commit 10ade80

Browse files
committed
Fix mesh axis bugs & Fix Wan-VACE transformer and pipeline to align with other Wan models and pipeline
* Remove an unnecessary `vae_spatial_axis_name` assignment to `vae_mesh` in `wan_pipeline.py`, that fixes `AttributeError: cannot assign to field 'vae_spatial_axis_name'`. * Adopts the shared `_create_common_components` method (used by `WanPipeline2_1`, `WanPipeline2_2`, etc.) into `_load_and_init` of `VaceWanPipeline2_1`. This align initialization and resolves mesh axis handling issues during `VaceWanPipeline2_1` initialization introduced by PR #359. * Adopts new parameters `mask_padding_tokens`, `enable_jax_named_scopes`, `use_base2_exp`, `use_experimental_scheduler` into `transformer_wan_vace.py` and `wan_vace_pipeline_2_1.py` (following `transformer_wan.py` and `wan_pipeline.py` respectively) that resolves attention errors (introduced by PR #359) and align Wan-VACE transformer to Wan transformer. * Removes the `load_common_components` argument from `from_pretrained` and `from_checkpoint` in `VaeWanPipeline2_1` for interface consistency.
1 parent c98002f commit 10ade80

3 files changed

Lines changed: 152 additions & 120 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py

Lines changed: 124 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
limitations under the License.
1414
"""
1515

16+
import contextlib
1617
import math
1718
from typing import Any, Dict, Optional, Tuple
1819

@@ -62,8 +63,12 @@ def __init__(
6263
precision: jax.lax.Precision | None = None,
6364
attention: str = "dot_product",
6465
dropout: float = 0.0,
66+
mask_padding_tokens: bool = True,
67+
enable_jax_named_scopes: bool = False,
6568
apply_input_projection: bool = False,
6669
apply_output_projection: bool = False,
70+
use_base2_exp: bool = False,
71+
use_experimental_scheduler: bool = False,
6772
):
6873
"""Sets up the model.
6974
@@ -90,7 +95,7 @@ def __init__(
9095
apply_output_projection: Whether to apply an output projection before
9196
outputting the result.
9297
"""
93-
98+
self.enable_jax_named_scopes = enable_jax_named_scopes
9499
self.apply_input_projection = apply_input_projection
95100
self.apply_output_projection = apply_output_projection
96101

@@ -124,7 +129,12 @@ def __init__(
124129
precision=precision,
125130
attention_kernel=attention,
126131
dropout=dropout,
132+
is_self_attention=True,
133+
mask_padding_tokens=mask_padding_tokens,
127134
residual_checkpoint_name="self_attn",
135+
enable_jax_named_scopes=enable_jax_named_scopes,
136+
use_base2_exp=use_base2_exp,
137+
use_experimental_scheduler=use_experimental_scheduler,
128138
)
129139

130140
# 3. Cross-attention
@@ -143,7 +153,12 @@ def __init__(
143153
precision=precision,
144154
attention_kernel=attention,
145155
dropout=dropout,
156+
is_self_attention=False,
157+
mask_padding_tokens=mask_padding_tokens,
146158
residual_checkpoint_name="cross_attn",
159+
enable_jax_named_scopes=enable_jax_named_scopes,
160+
use_base2_exp=use_base2_exp,
161+
use_experimental_scheduler=use_experimental_scheduler,
147162
)
148163
assert cross_attn_norm is True, "cross_attn_norm must be True"
149164
self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -158,6 +173,7 @@ def __init__(
158173
weights_dtype=weights_dtype,
159174
precision=precision,
160175
dropout=dropout,
176+
enable_jax_named_scopes=enable_jax_named_scopes,
161177
)
162178

163179
self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
@@ -180,6 +196,10 @@ def __init__(
180196
jax.random.normal(key, (1, 6, dim)) / dim**0.5,
181197
)
182198

199+
def conditional_named_scope(self, name: str):
200+
"""Return a JAX named scope if enabled, otherwise a null context."""
201+
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
202+
183203
def __call__(
184204
self,
185205
*,
@@ -191,65 +211,74 @@ def __call__(
191211
deterministic: bool = True,
192212
rngs: nnx.Rngs | None = None,
193213
) -> Tuple[jax.Array, jax.Array]:
194-
if self.apply_input_projection:
195-
control_hidden_states = self.proj_in(control_hidden_states)
196-
control_hidden_states = control_hidden_states + hidden_states
197-
198-
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
199-
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
200-
)
201-
202-
control_hidden_states = jax.lax.with_sharding_constraint(
203-
control_hidden_states,
204-
PartitionSpec("data", "fsdp", "tensor"),
205-
)
206-
control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
207-
encoder_hidden_states = jax.lax.with_sharding_constraint(
208-
encoder_hidden_states,
209-
PartitionSpec("data", "fsdp", None),
210-
)
211-
212-
# 1. Self-attention
213-
with jax.named_scope("attn1"):
214-
norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
215-
control_hidden_states.dtype
214+
with self.conditional_named_scope("vace_transformer_block"):
215+
with self.conditional_named_scope("input_projection"):
216+
if self.apply_input_projection:
217+
control_hidden_states = self.proj_in(control_hidden_states)
218+
control_hidden_states = control_hidden_states + hidden_states
219+
220+
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
221+
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
216222
)
217-
attn_output = self.attn1(
218-
hidden_states=norm_hidden_states,
219-
encoder_hidden_states=norm_hidden_states,
220-
rotary_emb=rotary_emb,
221-
deterministic=deterministic,
222-
rngs=rngs,
223-
)
224-
control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
225-
control_hidden_states.dtype
226-
)
227-
228-
# 2. Cross-attention
229-
with jax.named_scope("attn2"):
230-
norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
231-
attn_output = self.attn2(
232-
hidden_states=norm_hidden_states,
233-
encoder_hidden_states=encoder_hidden_states,
234-
deterministic=deterministic,
235-
rngs=rngs,
236-
)
237-
control_hidden_states = control_hidden_states + attn_output
238223

239-
# 3. Feed-forward
240-
with jax.named_scope("ffn"):
241-
norm_hidden_states = (self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
242-
control_hidden_states.dtype
243-
)
244-
ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
245-
control_hidden_states = (
246-
control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
247-
).astype(control_hidden_states.dtype)
248-
conditioning_states = None
249-
if self.apply_output_projection:
250-
conditioning_states = self.proj_out(control_hidden_states)
224+
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
225+
control_hidden_states = jax.lax.with_sharding_constraint(control_hidden_states, axis_names)
226+
control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
227+
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv"))
228+
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
229+
230+
# 1. Self-attention
231+
with self.conditional_named_scope("self_attn"):
232+
with self.conditional_named_scope("self_attn_norm"):
233+
norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
234+
control_hidden_states.dtype
235+
)
236+
with self.conditional_named_scope("self_attn_attn"):
237+
attn_output = self.attn1(
238+
hidden_states=norm_hidden_states,
239+
encoder_hidden_states=norm_hidden_states,
240+
rotary_emb=rotary_emb,
241+
deterministic=deterministic,
242+
rngs=rngs,
243+
)
244+
with self.conditional_named_scope("self_attn_residual"):
245+
control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
246+
control_hidden_states.dtype
247+
)
251248

252-
return conditioning_states, control_hidden_states
249+
# 2. Cross-attention
250+
with self.conditional_named_scope("cross_attn"):
251+
with self.conditional_named_scope("cross_attn_norm"):
252+
norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
253+
with self.conditional_named_scope("cross_attn_attn"):
254+
attn_output = self.attn2(
255+
hidden_states=norm_hidden_states,
256+
encoder_hidden_states=encoder_hidden_states,
257+
deterministic=deterministic,
258+
rngs=rngs,
259+
)
260+
with self.conditional_named_scope("cross_attn_residual"):
261+
control_hidden_states = control_hidden_states + attn_output
262+
263+
# 3. Feed-forward
264+
with self.conditional_named_scope("mlp"):
265+
with self.conditional_named_scope("mlp_norm"):
266+
norm_hidden_states = (
267+
self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa
268+
).astype(control_hidden_states.dtype)
269+
with self.conditional_named_scope("mlp_ffn"):
270+
ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
271+
with self.conditional_named_scope("mlp_residual"):
272+
control_hidden_states = (
273+
control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
274+
).astype(control_hidden_states.dtype)
275+
276+
with self.conditional_named_scope("output_projection"):
277+
conditioning_states = None
278+
if self.apply_output_projection:
279+
conditioning_states = self.proj_out(control_hidden_states)
280+
281+
return conditioning_states, control_hidden_states
253282

254283

255284
class WanVACEModel(WanModel):
@@ -289,7 +318,11 @@ def __init__(
289318
remat_policy: str = "None",
290319
names_which_can_be_saved: list[str] = [],
291320
names_which_can_be_offloaded: list[str] = [],
321+
mask_padding_tokens: bool = True,
292322
scan_layers: bool = True,
323+
enable_jax_named_scopes: bool = False,
324+
use_base2_exp: bool = False,
325+
use_experimental_scheduler: bool = False,
293326
):
294327
"""Initializes the VACE model.
295328
@@ -302,6 +335,7 @@ def __init__(
302335
out_channels = out_channels or in_channels
303336
self.num_layers = num_layers
304337
self.scan_layers = scan_layers
338+
self.enable_jax_named_scopes = enable_jax_named_scopes
305339

306340
# 1. Patch & position embedding
307341
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -329,6 +363,7 @@ def __init__(
329363
text_embed_dim=text_dim,
330364
image_embed_dim=image_dim,
331365
pos_embed_seq_len=pos_embed_seq_len,
366+
flash_min_seq_length=flash_min_seq_length,
332367
)
333368

334369
self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -358,6 +393,10 @@ def __init__(
358393
precision=precision,
359394
attention=attention,
360395
dropout=dropout,
396+
mask_padding_tokens=mask_padding_tokens,
397+
enable_jax_named_scopes=enable_jax_named_scopes,
398+
use_base2_exp=use_base2_exp,
399+
use_experimental_scheduler=use_experimental_scheduler,
361400
)
362401
blocks.append(block)
363402
self.blocks = blocks
@@ -384,8 +423,12 @@ def __init__(
384423
precision=precision,
385424
attention=attention,
386425
dropout=dropout,
426+
mask_padding_tokens=mask_padding_tokens,
427+
enable_jax_named_scopes=enable_jax_named_scopes,
387428
apply_input_projection=vace_block_id == 0,
388429
apply_output_projection=True,
430+
use_base2_exp=use_base2_exp,
431+
use_experimental_scheduler=use_experimental_scheduler,
389432
)
390433
vace_blocks.append(vace_block)
391434
self.vace_blocks = vace_blocks
@@ -421,6 +464,10 @@ def __init__(
421464
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
422465
)
423466

467+
def conditional_named_scope(self, name: str):
468+
"""Return a JAX named scope if enabled, otherwise a null context."""
469+
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
470+
424471
@jax.named_scope("WanVACEModel")
425472
def __call__(
426473
self,
@@ -436,7 +483,7 @@ def __call__(
436483
rngs: nnx.Rngs = None,
437484
) -> jax.Array:
438485
hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
439-
batch_size, num_channels, num_frames, height, width = hidden_states.shape
486+
batch_size, _, num_frames, height, width = hidden_states.shape
440487
p_t, p_h, p_w = self.config.patch_size
441488
post_patch_num_frames = num_frames // p_t
442489
post_patch_height = height // p_h
@@ -453,13 +500,14 @@ def __call__(
453500

454501
hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
455502
control_hidden_states = jnp.transpose(control_hidden_states, (0, 2, 3, 4, 1))
456-
rotary_emb = self.rope(hidden_states)
457-
458-
hidden_states = self.patch_embedding(hidden_states)
459-
hidden_states = jax.lax.collapse(hidden_states, 1, -1)
460-
461-
control_hidden_states = self.vace_patch_embedding(control_hidden_states)
462-
control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
503+
with self.conditional_named_scope("rotary_embedding"):
504+
rotary_emb = self.rope(hidden_states)
505+
with self.conditional_named_scope("patch_embedding"):
506+
hidden_states = self.patch_embedding(hidden_states)
507+
hidden_states = jax.lax.collapse(hidden_states, 1, -1)
508+
509+
control_hidden_states = self.vace_patch_embedding(control_hidden_states)
510+
control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
463511
control_hidden_states_padding = jnp.zeros((
464512
batch_size,
465513
control_hidden_states.shape[1],
@@ -469,16 +517,17 @@ def __call__(
469517
control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2)
470518

471519
# Condition embedder is a FC layer.
472-
(
473-
temb,
474-
timestep_proj,
475-
encoder_hidden_states,
476-
encoder_hidden_states_image,
477-
_,
478-
) = self.condition_embedder( # We will need to mask out the text embedding.
479-
timestep, encoder_hidden_states, encoder_hidden_states_image
480-
)
481-
timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
520+
with self.conditional_named_scope("condition_embedder"):
521+
(
522+
temb,
523+
timestep_proj,
524+
encoder_hidden_states,
525+
encoder_hidden_states_image,
526+
_,
527+
) = self.condition_embedder( # We will need to mask out the text embedding.
528+
timestep, encoder_hidden_states, encoder_hidden_states_image
529+
)
530+
timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
482531

483532
if encoder_hidden_states_image is not None:
484533
raise NotImplementedError("img2vid is not yet implemented.")

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,6 @@ def _create_common_components(cls, config, vae_only=False, i2v=False):
630630
vae_devices_array = flat_devices.reshape(total_devices // vae_spatial, vae_spatial)
631631

632632
vae_mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial"))
633-
vae_mesh.vae_spatial_axis_name = "vae_spatial"
634633
max_logging.log(
635634
f"Created VAE specific mesh with axes ('redundant', 'vae_spatial') to support spatial sharding of {vae_spatial}."
636635
)

0 commit comments

Comments
 (0)