
Commit b1b0a18

Fix mesh axis bugs and align the Wan-VACE transformer and pipeline with the other Wan models and pipelines
* Removes an unnecessary `vae_spatial_axis_name` assignment on `vae_mesh` in `wan_pipeline.py`, which fixes `AttributeError: cannot assign to field 'vae_spatial_axis_name'`.
* Adopts the shared `_create_common_components` method (used by `WanPipeline2_1`, `WanPipeline2_2`, etc.) in `_load_and_init` of `VaceWanPipeline2_1`. This aligns initialization and resolves the mesh axis handling issues during `VaceWanPipeline2_1` initialization introduced by PR #359.
* Adds the new parameters `mask_padding_tokens`, `enable_jax_named_scopes`, `use_base2_exp`, and `use_experimental_scheduler` to `transformer_wan_vace.py` and `wan_vace_pipeline_2_1.py` (following `transformer_wan.py` and `wan_pipeline.py`, respectively). This resolves attention errors introduced by PR #359 and aligns the Wan-VACE transformer with the Wan transformer.
* Removes the `load_common_components` argument from `from_pretrained` and `from_checkpoint` in `VaceWanPipeline2_1` for interface consistency.
1 parent c98002f commit b1b0a18

3 files changed

Lines changed: 152 additions & 118 deletions


src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py

Lines changed: 123 additions & 73 deletions
@@ -13,6 +13,7 @@
 limitations under the License.
 """
 
+import contextlib
 import math
 from typing import Any, Dict, Optional, Tuple
 
@@ -62,8 +63,12 @@ def __init__(
       precision: jax.lax.Precision | None = None,
       attention: str = "dot_product",
       dropout: float = 0.0,
+      mask_padding_tokens: bool = True,
+      enable_jax_named_scopes: bool = False,
       apply_input_projection: bool = False,
       apply_output_projection: bool = False,
+      use_base2_exp: bool = False,
+      use_experimental_scheduler: bool = False,
   ):
     """Sets up the model.
 
@@ -90,7 +95,7 @@ def __init__(
       apply_output_projection: Whether to apply an output projection before
         outputting the result.
     """
-
+    self.enable_jax_named_scopes = enable_jax_named_scopes
     self.apply_input_projection = apply_input_projection
     self.apply_output_projection = apply_output_projection
 
@@ -124,7 +129,12 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=True,
+        mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="self_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
+        use_base2_exp=use_base2_exp,
+        use_experimental_scheduler=use_experimental_scheduler,
     )
 
     # 3. Cross-attention
@@ -143,7 +153,12 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=False,
+        mask_padding_tokens=mask_padding_tokens,
         residual_checkpoint_name="cross_attn",
+        enable_jax_named_scopes=enable_jax_named_scopes,
+        use_base2_exp=use_base2_exp,
+        use_experimental_scheduler=use_experimental_scheduler,
     )
     assert cross_attn_norm is True, "cross_attn_norm must be True"
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -158,6 +173,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
         dropout=dropout,
+        enable_jax_named_scopes=enable_jax_named_scopes,
     )
 
     self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
@@ -180,6 +196,10 @@ def __init__(
         jax.random.normal(key, (1, 6, dim)) / dim**0.5,
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   def __call__(
       self,
       *,
@@ -191,65 +211,75 @@ def __call__(
       deterministic: bool = True,
       rngs: nnx.Rngs | None = None,
   ) -> Tuple[jax.Array, jax.Array]:
-    if self.apply_input_projection:
-      control_hidden_states = self.proj_in(control_hidden_states)
-      control_hidden_states = control_hidden_states + hidden_states
-
-    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-        (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
-    )
 
-    control_hidden_states = jax.lax.with_sharding_constraint(
-        control_hidden_states,
-        PartitionSpec("data", "fsdp", "tensor"),
-    )
-    control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
-    encoder_hidden_states = jax.lax.with_sharding_constraint(
-        encoder_hidden_states,
-        PartitionSpec("data", "fsdp", None),
-    )
+    with self.conditional_named_scope("vace_transformer_block"):
+      with self.conditional_named_scope("input_projection"):
+        if self.apply_input_projection:
+          control_hidden_states = self.proj_in(control_hidden_states)
+          control_hidden_states = control_hidden_states + hidden_states
 
-    # 1. Self-attention
-    with jax.named_scope("attn1"):
-      norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
-          control_hidden_states.dtype
-      )
-      attn_output = self.attn1(
-          hidden_states=norm_hidden_states,
-          encoder_hidden_states=norm_hidden_states,
-          rotary_emb=rotary_emb,
-          deterministic=deterministic,
-          rngs=rngs,
+      shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+          (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
      )
-      control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
-          control_hidden_states.dtype
-      )
-
-    # 2. Cross-attention
-    with jax.named_scope("attn2"):
-      norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
-      attn_output = self.attn2(
-          hidden_states=norm_hidden_states,
-          encoder_hidden_states=encoder_hidden_states,
-          deterministic=deterministic,
-          rngs=rngs,
-      )
-      control_hidden_states = control_hidden_states + attn_output
 
-    # 3. Feed-forward
-    with jax.named_scope("ffn"):
-      norm_hidden_states = (self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
-          control_hidden_states.dtype
-      )
-      ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-      control_hidden_states = (
-          control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
-      ).astype(control_hidden_states.dtype)
-    conditioning_states = None
-    if self.apply_output_projection:
-      conditioning_states = self.proj_out(control_hidden_states)
+      axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
+      control_hidden_states = jax.lax.with_sharding_constraint(control_hidden_states, axis_names)
+      control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
+      axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv"))
+      encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
+
+      # 1. Self-attention
+      with self.conditional_named_scope("self_attn"):
+        with self.conditional_named_scope("self_attn_norm"):
+          norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
+              control_hidden_states.dtype
+          )
+        with self.conditional_named_scope("self_attn_attn"):
+          attn_output = self.attn1(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=norm_hidden_states,
+              rotary_emb=rotary_emb,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with self.conditional_named_scope("self_attn_residual"):
+          control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
+              control_hidden_states.dtype
+          )
 
-    return conditioning_states, control_hidden_states
+      # 2. Cross-attention
+      with self.conditional_named_scope("cross_attn"):
+        with self.conditional_named_scope("cross_attn_norm"):
+          norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
+        with self.conditional_named_scope("cross_attn_attn"):
+          attn_output = self.attn2(
+              hidden_states=norm_hidden_states,
+              encoder_hidden_states=encoder_hidden_states,
+              deterministic=deterministic,
+              rngs=rngs,
+          )
+        with self.conditional_named_scope("cross_attn_residual"):
+          control_hidden_states = control_hidden_states + attn_output
+
+      # 3. Feed-forward
+      with self.conditional_named_scope("mlp"):
+        with self.conditional_named_scope("mlp_norm"):
+          norm_hidden_states = (
+              self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa
+          ).astype(control_hidden_states.dtype)
+        with self.conditional_named_scope("mlp_ffn"):
+          ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
+        with self.conditional_named_scope("mlp_residual"):
+          control_hidden_states = (
+              control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
+          ).astype(control_hidden_states.dtype)
+
+      with self.conditional_named_scope("output_projection"):
+        conditioning_states = None
+        if self.apply_output_projection:
+          conditioning_states = self.proj_out(control_hidden_states)
+
+      return conditioning_states, control_hidden_states
 
 
 class WanVACEModel(WanModel):
@@ -289,7 +319,11 @@ def __init__(
       remat_policy: str = "None",
       names_which_can_be_saved: list[str] = [],
       names_which_can_be_offloaded: list[str] = [],
+      mask_padding_tokens: bool = True,
       scan_layers: bool = True,
+      enable_jax_named_scopes: bool = False,
+      use_base2_exp: bool = False,
+      use_experimental_scheduler: bool = False,
   ):
     """Initializes the VACE model.
 
@@ -302,6 +336,7 @@ def __init__(
     out_channels = out_channels or in_channels
     self.num_layers = num_layers
    self.scan_layers = scan_layers
+    self.enable_jax_named_scopes = enable_jax_named_scopes
 
     # 1. Patch & position embedding
     self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -329,6 +364,7 @@ def __init__(
         text_embed_dim=text_dim,
         image_embed_dim=image_dim,
         pos_embed_seq_len=pos_embed_seq_len,
+        flash_min_seq_length=flash_min_seq_length,
     )
 
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -358,6 +394,10 @@ def __init__(
           precision=precision,
           attention=attention,
           dropout=dropout,
+          mask_padding_tokens=mask_padding_tokens,
+          enable_jax_named_scopes=enable_jax_named_scopes,
+          use_base2_exp=use_base2_exp,
+          use_experimental_scheduler=use_experimental_scheduler,
       )
       blocks.append(block)
     self.blocks = blocks
@@ -384,8 +424,12 @@ def __init__(
           precision=precision,
           attention=attention,
           dropout=dropout,
+          mask_padding_tokens=mask_padding_tokens,
+          enable_jax_named_scopes=enable_jax_named_scopes,
           apply_input_projection=vace_block_id == 0,
           apply_output_projection=True,
+          use_base2_exp=use_base2_exp,
+          use_experimental_scheduler=use_experimental_scheduler,
       )
       vace_blocks.append(vace_block)
     self.vace_blocks = vace_blocks
@@ -421,6 +465,10 @@ def __init__(
         kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
     )
 
+  def conditional_named_scope(self, name: str):
+    """Return a JAX named scope if enabled, otherwise a null context."""
+    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
+
   @jax.named_scope("WanVACEModel")
   def __call__(
       self,
@@ -436,7 +484,7 @@ def __call__(
       rngs: nnx.Rngs = None,
   ) -> jax.Array:
     hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
-    batch_size, num_channels, num_frames, height, width = hidden_states.shape
+    batch_size, _, num_frames, height, width = hidden_states.shape
     p_t, p_h, p_w = self.config.patch_size
     post_patch_num_frames = num_frames // p_t
     post_patch_height = height // p_h
@@ -453,13 +501,14 @@ def __call__(
 
     hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
     control_hidden_states = jnp.transpose(control_hidden_states, (0, 2, 3, 4, 1))
-    rotary_emb = self.rope(hidden_states)
-
-    hidden_states = self.patch_embedding(hidden_states)
-    hidden_states = jax.lax.collapse(hidden_states, 1, -1)
-
-    control_hidden_states = self.vace_patch_embedding(control_hidden_states)
-    control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
+    with self.conditional_named_scope("rotary_embedding"):
+      rotary_emb = self.rope(hidden_states)
+    with self.conditional_named_scope("patch_embedding"):
+      hidden_states = self.patch_embedding(hidden_states)
+      hidden_states = jax.lax.collapse(hidden_states, 1, -1)
+
+      control_hidden_states = self.vace_patch_embedding(control_hidden_states)
+      control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
     control_hidden_states_padding = jnp.zeros((
         batch_size,
         control_hidden_states.shape[1],
@@ -469,16 +518,17 @@ def __call__(
     control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2)
 
     # Condition embedder is a FC layer.
-    (
-        temb,
-        timestep_proj,
-        encoder_hidden_states,
-        encoder_hidden_states_image,
-        _,
-    ) = self.condition_embedder(  # We will need to mask out the text embedding.
-        timestep, encoder_hidden_states, encoder_hidden_states_image
-    )
-    timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
+    with self.conditional_named_scope("condition_embedder"):
+      (
+          temb,
+          timestep_proj,
+          encoder_hidden_states,
+          encoder_hidden_states_image,
+          _,
+      ) = self.condition_embedder(  # We will need to mask out the text embedding.
+          timestep, encoder_hidden_states, encoder_hidden_states_image
+      )
+      timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
 
     if encoder_hidden_states_image is not None:
       raise NotImplementedError("img2vid is not yet implemented.")
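
The `conditional_named_scope` helper added in this file is just a toggle between `jax.named_scope` and a no-op context manager, so profiler annotations can be switched on per run without changing call sites. Below is a minimal standalone sketch of the same pattern; the `ToyBlock` module and its sizes are illustrative and not part of this commit.

import contextlib

import jax
import jax.numpy as jnp
from flax import nnx


class ToyBlock(nnx.Module):
  """Illustrative module using the conditional named-scope pattern from this commit."""

  def __init__(self, dim: int, rngs: nnx.Rngs, enable_jax_named_scopes: bool = False):
    self.enable_jax_named_scopes = enable_jax_named_scopes
    self.proj = nnx.Linear(dim, dim, rngs=rngs)

  def conditional_named_scope(self, name: str):
    # jax.named_scope labels ops in traces/profiles; nullcontext makes it a no-op.
    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

  def __call__(self, x: jax.Array) -> jax.Array:
    with self.conditional_named_scope("toy_block"):
      with self.conditional_named_scope("projection"):
        x = self.proj(x)
    return x


block = ToyBlock(dim=8, rngs=nnx.Rngs(0), enable_jax_named_scopes=True)
y = block(jnp.ones((2, 8)))
print(y.shape)  # (2, 8)

With `enable_jax_named_scopes=True` the scope names appear in profiler traces; with the default `False` the blocks run with plain, unannotated contexts.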

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 0 additions & 1 deletion
@@ -630,7 +630,6 @@ def _create_common_components(cls, config, vae_only=False, i2v=False):
     vae_devices_array = flat_devices.reshape(total_devices // vae_spatial, vae_spatial)
 
     vae_mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial"))
-    vae_mesh.vae_spatial_axis_name = "vae_spatial"
     max_logging.log(
         f"Created VAE specific mesh with axes ('redundant', 'vae_spatial') to support spatial sharding of {vae_spatial}."
     )
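
For context on the removed line: `jax.sharding.Mesh` does not accept new attributes after construction, which is what raised the `AttributeError: cannot assign to field 'vae_spatial_axis_name'` mentioned in the commit message. A minimal sketch of the working pattern follows; the device reshape is illustrative rather than the pipeline's actual `total_devices // vae_spatial` logic.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Arrange the available devices in a 2D grid; the reshape here is illustrative.
devices = np.array(jax.devices()).reshape(1, -1)
vae_mesh = Mesh(devices, ("redundant", "vae_spatial"))

# Assigning a custom field fails, as reported in the commit message:
#   vae_mesh.vae_spatial_axis_name = "vae_spatial"   # AttributeError
# The axis name is already recorded on the mesh itself:
assert "vae_spatial" in vae_mesh.axis_names

# It can then be used directly when building shardings for VAE activations.
vae_sharding = NamedSharding(vae_mesh, PartitionSpec(None, "vae_spatial"))
print(vae_mesh.shape)     # e.g. {'redundant': 1, 'vae_spatial': <num_devices>}
print(vae_sharding.spec)

Since the axis name is already stored on the mesh, downstream code can read `vae_mesh.axis_names` or pass "vae_spatial" straight to `PartitionSpec`, which is why the extra attribute was unnecessary.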
