Skip to content

Commit fa877ab

Browse files
committed
Fix mesh axis bugs & Fix Wan-VACE transformer and pipeline to align with other Wan models and pipeline
* Remove an unnecessary `vae_spatial_axis_name` assignment to `vae_mesh` in `wan_pipeline.py`, that fixes `AttributeError: cannot assign to field 'vae_spatial_axis_name'`. * Adopts the shared `_create_common_components` method (used by `WanPipeline2_1`, `WanPipeline2_2`, etc.) into `_load_and_init` of `VaceWanPipeline2_1`. This align initialization and resolves mesh axis handling issues during `VaceWanPipeline2_1` initialization introduced by PR #359. * Adopts new parameters `mask_padding_tokens`, `enable_jax_named_scopes`, `use_base2_exp`, `use_experimental_scheduler` into `transformer_wan_vace.py` and `wan_vace_pipeline_2_1.py` (following `transformer_wan.py` and `wan_pipeline.py` respectively) that resolves attention errors (introduced by PR #359) and align Wan-VACE transformer to Wan transformer. * Removes the `load_common_components` argument from `from_pretrained` and `from_checkpoint` in `VaeWanPipeline2_1` for interface consistency.
1 parent c98002f commit fa877ab

3 files changed

Lines changed: 152 additions & 122 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan_vace.py

Lines changed: 124 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
limitations under the License.
1414
"""
1515

16+
import contextlib
1617
import math
1718
from typing import Any, Dict, Optional, Tuple
1819

@@ -21,7 +22,6 @@
2122
import jax
2223
from jax.ad_checkpoint import checkpoint_name
2324
import jax.numpy as jnp
24-
from jax.sharding import PartitionSpec
2525

2626
from .... import common_types
2727
from ....configuration_utils import register_to_config
@@ -62,8 +62,12 @@ def __init__(
6262
precision: jax.lax.Precision | None = None,
6363
attention: str = "dot_product",
6464
dropout: float = 0.0,
65+
mask_padding_tokens: bool = True,
66+
enable_jax_named_scopes: bool = False,
6567
apply_input_projection: bool = False,
6668
apply_output_projection: bool = False,
69+
use_base2_exp: bool = False,
70+
use_experimental_scheduler: bool = False,
6771
):
6872
"""Sets up the model.
6973
@@ -90,7 +94,7 @@ def __init__(
9094
apply_output_projection: Whether to apply an output projection before
9195
outputting the result.
9296
"""
93-
97+
self.enable_jax_named_scopes = enable_jax_named_scopes
9498
self.apply_input_projection = apply_input_projection
9599
self.apply_output_projection = apply_output_projection
96100

@@ -124,7 +128,12 @@ def __init__(
124128
precision=precision,
125129
attention_kernel=attention,
126130
dropout=dropout,
131+
is_self_attention=True,
132+
mask_padding_tokens=mask_padding_tokens,
127133
residual_checkpoint_name="self_attn",
134+
enable_jax_named_scopes=enable_jax_named_scopes,
135+
use_base2_exp=use_base2_exp,
136+
use_experimental_scheduler=use_experimental_scheduler,
128137
)
129138

130139
# 3. Cross-attention
@@ -143,7 +152,12 @@ def __init__(
143152
precision=precision,
144153
attention_kernel=attention,
145154
dropout=dropout,
155+
is_self_attention=False,
156+
mask_padding_tokens=mask_padding_tokens,
146157
residual_checkpoint_name="cross_attn",
158+
enable_jax_named_scopes=enable_jax_named_scopes,
159+
use_base2_exp=use_base2_exp,
160+
use_experimental_scheduler=use_experimental_scheduler,
147161
)
148162
assert cross_attn_norm is True, "cross_attn_norm must be True"
149163
self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -158,6 +172,7 @@ def __init__(
158172
weights_dtype=weights_dtype,
159173
precision=precision,
160174
dropout=dropout,
175+
enable_jax_named_scopes=enable_jax_named_scopes,
161176
)
162177

163178
self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
@@ -180,6 +195,10 @@ def __init__(
180195
jax.random.normal(key, (1, 6, dim)) / dim**0.5,
181196
)
182197

198+
def conditional_named_scope(self, name: str):
199+
"""Return a JAX named scope if enabled, otherwise a null context."""
200+
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
201+
183202
def __call__(
184203
self,
185204
*,
@@ -191,65 +210,74 @@ def __call__(
191210
deterministic: bool = True,
192211
rngs: nnx.Rngs | None = None,
193212
) -> Tuple[jax.Array, jax.Array]:
194-
if self.apply_input_projection:
195-
control_hidden_states = self.proj_in(control_hidden_states)
196-
control_hidden_states = control_hidden_states + hidden_states
197-
198-
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
199-
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
200-
)
201-
202-
control_hidden_states = jax.lax.with_sharding_constraint(
203-
control_hidden_states,
204-
PartitionSpec("data", "fsdp", "tensor"),
205-
)
206-
control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
207-
encoder_hidden_states = jax.lax.with_sharding_constraint(
208-
encoder_hidden_states,
209-
PartitionSpec("data", "fsdp", None),
210-
)
211-
212-
# 1. Self-attention
213-
with jax.named_scope("attn1"):
214-
norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
215-
control_hidden_states.dtype
213+
with self.conditional_named_scope("vace_transformer_block"):
214+
with self.conditional_named_scope("input_projection"):
215+
if self.apply_input_projection:
216+
control_hidden_states = self.proj_in(control_hidden_states)
217+
control_hidden_states = control_hidden_states + hidden_states
218+
219+
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
220+
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
216221
)
217-
attn_output = self.attn1(
218-
hidden_states=norm_hidden_states,
219-
encoder_hidden_states=norm_hidden_states,
220-
rotary_emb=rotary_emb,
221-
deterministic=deterministic,
222-
rngs=rngs,
223-
)
224-
control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
225-
control_hidden_states.dtype
226-
)
227-
228-
# 2. Cross-attention
229-
with jax.named_scope("attn2"):
230-
norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
231-
attn_output = self.attn2(
232-
hidden_states=norm_hidden_states,
233-
encoder_hidden_states=encoder_hidden_states,
234-
deterministic=deterministic,
235-
rngs=rngs,
236-
)
237-
control_hidden_states = control_hidden_states + attn_output
238222

239-
# 3. Feed-forward
240-
with jax.named_scope("ffn"):
241-
norm_hidden_states = (self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
242-
control_hidden_states.dtype
243-
)
244-
ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
245-
control_hidden_states = (
246-
control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
247-
).astype(control_hidden_states.dtype)
248-
conditioning_states = None
249-
if self.apply_output_projection:
250-
conditioning_states = self.proj_out(control_hidden_states)
223+
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
224+
control_hidden_states = jax.lax.with_sharding_constraint(control_hidden_states, axis_names)
225+
control_hidden_states = checkpoint_name(control_hidden_states, "control_hidden_states")
226+
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv"))
227+
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
228+
229+
# 1. Self-attention
230+
with self.conditional_named_scope("self_attn"):
231+
with self.conditional_named_scope("self_attn_norm"):
232+
norm_hidden_states = (self.norm1(control_hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
233+
control_hidden_states.dtype
234+
)
235+
with self.conditional_named_scope("self_attn_attn"):
236+
attn_output = self.attn1(
237+
hidden_states=norm_hidden_states,
238+
encoder_hidden_states=norm_hidden_states,
239+
rotary_emb=rotary_emb,
240+
deterministic=deterministic,
241+
rngs=rngs,
242+
)
243+
with self.conditional_named_scope("self_attn_residual"):
244+
control_hidden_states = (control_hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(
245+
control_hidden_states.dtype
246+
)
251247

252-
return conditioning_states, control_hidden_states
248+
# 2. Cross-attention
249+
with self.conditional_named_scope("cross_attn"):
250+
with self.conditional_named_scope("cross_attn_norm"):
251+
norm_hidden_states = self.norm2(control_hidden_states.astype(jnp.float32)).astype(control_hidden_states.dtype)
252+
with self.conditional_named_scope("cross_attn_attn"):
253+
attn_output = self.attn2(
254+
hidden_states=norm_hidden_states,
255+
encoder_hidden_states=encoder_hidden_states,
256+
deterministic=deterministic,
257+
rngs=rngs,
258+
)
259+
with self.conditional_named_scope("cross_attn_residual"):
260+
control_hidden_states = control_hidden_states + attn_output
261+
262+
# 3. Feed-forward
263+
with self.conditional_named_scope("mlp"):
264+
with self.conditional_named_scope("mlp_norm"):
265+
norm_hidden_states = (
266+
self.norm3(control_hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa
267+
).astype(control_hidden_states.dtype)
268+
with self.conditional_named_scope("mlp_ffn"):
269+
ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
270+
with self.conditional_named_scope("mlp_residual"):
271+
control_hidden_states = (
272+
control_hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa
273+
).astype(control_hidden_states.dtype)
274+
275+
with self.conditional_named_scope("output_projection"):
276+
conditioning_states = None
277+
if self.apply_output_projection:
278+
conditioning_states = self.proj_out(control_hidden_states)
279+
280+
return conditioning_states, control_hidden_states
253281

254282

255283
class WanVACEModel(WanModel):
@@ -289,7 +317,11 @@ def __init__(
289317
remat_policy: str = "None",
290318
names_which_can_be_saved: list[str] = [],
291319
names_which_can_be_offloaded: list[str] = [],
320+
mask_padding_tokens: bool = True,
292321
scan_layers: bool = True,
322+
enable_jax_named_scopes: bool = False,
323+
use_base2_exp: bool = False,
324+
use_experimental_scheduler: bool = False,
293325
):
294326
"""Initializes the VACE model.
295327
@@ -302,6 +334,7 @@ def __init__(
302334
out_channels = out_channels or in_channels
303335
self.num_layers = num_layers
304336
self.scan_layers = scan_layers
337+
self.enable_jax_named_scopes = enable_jax_named_scopes
305338

306339
# 1. Patch & position embedding
307340
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -329,6 +362,7 @@ def __init__(
329362
text_embed_dim=text_dim,
330363
image_embed_dim=image_dim,
331364
pos_embed_seq_len=pos_embed_seq_len,
365+
flash_min_seq_length=flash_min_seq_length,
332366
)
333367

334368
self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
@@ -358,6 +392,10 @@ def __init__(
358392
precision=precision,
359393
attention=attention,
360394
dropout=dropout,
395+
mask_padding_tokens=mask_padding_tokens,
396+
enable_jax_named_scopes=enable_jax_named_scopes,
397+
use_base2_exp=use_base2_exp,
398+
use_experimental_scheduler=use_experimental_scheduler,
361399
)
362400
blocks.append(block)
363401
self.blocks = blocks
@@ -384,8 +422,12 @@ def __init__(
384422
precision=precision,
385423
attention=attention,
386424
dropout=dropout,
425+
mask_padding_tokens=mask_padding_tokens,
426+
enable_jax_named_scopes=enable_jax_named_scopes,
387427
apply_input_projection=vace_block_id == 0,
388428
apply_output_projection=True,
429+
use_base2_exp=use_base2_exp,
430+
use_experimental_scheduler=use_experimental_scheduler,
389431
)
390432
vace_blocks.append(vace_block)
391433
self.vace_blocks = vace_blocks
@@ -421,6 +463,10 @@ def __init__(
421463
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")),
422464
)
423465

466+
def conditional_named_scope(self, name: str):
467+
"""Return a JAX named scope if enabled, otherwise a null context."""
468+
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
469+
424470
@jax.named_scope("WanVACEModel")
425471
def __call__(
426472
self,
@@ -436,7 +482,7 @@ def __call__(
436482
rngs: nnx.Rngs = None,
437483
) -> jax.Array:
438484
hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
439-
batch_size, num_channels, num_frames, height, width = hidden_states.shape
485+
batch_size, _, num_frames, height, width = hidden_states.shape
440486
p_t, p_h, p_w = self.config.patch_size
441487
post_patch_num_frames = num_frames // p_t
442488
post_patch_height = height // p_h
@@ -453,13 +499,14 @@ def __call__(
453499

454500
hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1))
455501
control_hidden_states = jnp.transpose(control_hidden_states, (0, 2, 3, 4, 1))
456-
rotary_emb = self.rope(hidden_states)
457-
458-
hidden_states = self.patch_embedding(hidden_states)
459-
hidden_states = jax.lax.collapse(hidden_states, 1, -1)
460-
461-
control_hidden_states = self.vace_patch_embedding(control_hidden_states)
462-
control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
502+
with self.conditional_named_scope("rotary_embedding"):
503+
rotary_emb = self.rope(hidden_states)
504+
with self.conditional_named_scope("patch_embedding"):
505+
hidden_states = self.patch_embedding(hidden_states)
506+
hidden_states = jax.lax.collapse(hidden_states, 1, -1)
507+
508+
control_hidden_states = self.vace_patch_embedding(control_hidden_states)
509+
control_hidden_states = jax.lax.collapse(control_hidden_states, 1, -1)
463510
control_hidden_states_padding = jnp.zeros((
464511
batch_size,
465512
control_hidden_states.shape[1],
@@ -469,16 +516,17 @@ def __call__(
469516
control_hidden_states = jnp.concatenate([control_hidden_states, control_hidden_states_padding], axis=2)
470517

471518
# Condition embedder is a FC layer.
472-
(
473-
temb,
474-
timestep_proj,
475-
encoder_hidden_states,
476-
encoder_hidden_states_image,
477-
_,
478-
) = self.condition_embedder( # We will need to mask out the text embedding.
479-
timestep, encoder_hidden_states, encoder_hidden_states_image
480-
)
481-
timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
519+
with self.conditional_named_scope("condition_embedder"):
520+
(
521+
temb,
522+
timestep_proj,
523+
encoder_hidden_states,
524+
encoder_hidden_states_image,
525+
_,
526+
) = self.condition_embedder( # We will need to mask out the text embedding.
527+
timestep, encoder_hidden_states, encoder_hidden_states_image
528+
)
529+
timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
482530

483531
if encoder_hidden_states_image is not None:
484532
raise NotImplementedError("img2vid is not yet implemented.")

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,6 @@ def _create_common_components(cls, config, vae_only=False, i2v=False):
630630
vae_devices_array = flat_devices.reshape(total_devices // vae_spatial, vae_spatial)
631631

632632
vae_mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial"))
633-
vae_mesh.vae_spatial_axis_name = "vae_spatial"
634633
max_logging.log(
635634
f"Created VAE specific mesh with axes ('redundant', 'vae_spatial') to support spatial sharding of {vae_spatial}."
636635
)

0 commit comments

Comments
 (0)