From cda888dca989080f5f85631617d84104f3203d36 Mon Sep 17 00:00:00 2001 From: eltsai Date: Wed, 19 Nov 2025 23:49:37 +0000 Subject: [PATCH 1/4] Added named_scope for WAN 2.1 Xprof profiling --- src/maxdiffusion/models/attention_flax.py | 40 +++--- .../wan/transformers/transformer_wan.py | 119 ++++++++++-------- 2 files changed, 94 insertions(+), 65 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 6a578899e..db69c399a 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -965,30 +965,42 @@ def __call__( dtype = hidden_states.dtype if encoder_hidden_states is None: encoder_hidden_states = hidden_states - - query_proj = self.query(hidden_states) - key_proj = self.key(encoder_hidden_states) - value_proj = self.value(encoder_hidden_states) + + with jax.named_scope("attn_qkv_proj"): + with jax.named_scope("proj_query"): + query_proj = self.query(hidden_states) + with jax.named_scope("proj_key"): + key_proj = self.key(encoder_hidden_states) + with jax.named_scope("proj_value"): + value_proj = self.value(encoder_hidden_states) if self.qk_norm: - query_proj = self.norm_q(query_proj) - key_proj = self.norm_k(key_proj) + with jax.named_scope("attn_q_norm"): + query_proj = self.norm_q(query_proj) + with jax.named_scope("attn_k_norm"): + key_proj = self.norm_k(key_proj) + if rotary_emb is not None: - query_proj = _unflatten_heads(query_proj, self.heads) - key_proj = _unflatten_heads(key_proj, self.heads) - value_proj = _unflatten_heads(value_proj, self.heads) - # output of _unflatten_heads Batch, heads, seq_len, head_dim - query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) + with jax.named_scope("attn_rope"): + query_proj = _unflatten_heads(query_proj, self.heads) + key_proj = _unflatten_heads(key_proj, self.heads) + value_proj = _unflatten_heads(value_proj, self.heads) + # output of _unflatten_heads Batch, heads, seq_len, head_dim + query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) query_proj = checkpoint_name(query_proj, "query_proj") key_proj = checkpoint_name(key_proj, "key_proj") value_proj = checkpoint_name(value_proj, "value_proj") - attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) + + with jax.named_scope("attn_compute"): + attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) attn_output = attn_output.astype(dtype=dtype) attn_output = checkpoint_name(attn_output, "attn_output") - hidden_states = self.proj_attn(attn_output) - hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) + + with jax.named_scope("attn_out_proj"): + hidden_states = self.proj_attn(attn_output) + hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) return hidden_states diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 4dc21d432..f411a862b 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -237,10 +237,12 @@ def __init__( ) def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array: - hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824) - hidden_states = checkpoint_name(hidden_states, "ffn_activation") - hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) - return self.proj_out(hidden_states) # output is (4, 75600, 5120) + with jax.named_scope("mlp_up_proj_and_gelu"): + hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824) + hidden_states = checkpoint_name(hidden_states, "ffn_activation") + hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) + with jax.named_scope("mlp_down_proj"): + return self.proj_out(hidden_states) # output is (4, 75600, 5120) class WanTransformerBlock(nnx.Module): @@ -339,45 +341,59 @@ def __call__( deterministic: bool = True, rngs: nnx.Rngs = None, ): - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( - (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 - ) - hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) - hidden_states = checkpoint_name(hidden_states, "hidden_states") - encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None)) - - # 1. Self-attention - norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype( - hidden_states.dtype - ) - attn_output = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_hidden_states, - rotary_emb=rotary_emb, - deterministic=deterministic, - rngs=rngs, - ) - hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype) - - # 2. Cross-attention - norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype) - attn_output = self.attn2( - hidden_states=norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - deterministic=deterministic, - rngs=rngs, - ) - hidden_states = hidden_states + attn_output - - # 3. Feed-forward - norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype( - hidden_states.dtype - ) - ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs) - hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype( - hidden_states.dtype - ) - return hidden_states + with jax.named_scope("transformer_block"): + with jax.named_scope("adaln"): + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( + (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 + ) + hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor")) + hidden_states = checkpoint_name(hidden_states, "hidden_states") + encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None)) + + # 1. Self-attention + with jax.named_scope("self_attn"): + with jax.named_scope("self_attn_norm"): + norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype( + hidden_states.dtype + ) + with jax.named_scope("self_attn_attn"): + attn_output = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + rotary_emb=rotary_emb, + deterministic=deterministic, + rngs=rngs, + ) + with jax.named_scope("self_attn_residual"): + hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype) + + # 2. Cross-attention + with jax.named_scope("cross_attn"): + with jax.named_scope("cross_attn_norm"): + norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype) + with jax.named_scope("cross_attn_attn"): + attn_output = self.attn2( + hidden_states=norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + deterministic=deterministic, + rngs=rngs, + ) + with jax.named_scope("cross_attn_residual"): + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + with jax.named_scope("mlp"): + with jax.named_scope("mlp_norm"): + norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype( + hidden_states.dtype + ) + with jax.named_scope("mlp_ffn"): + ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs) + with jax.named_scope("mlp_residual"): + hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype( + hidden_states.dtype + ) + return hidden_states class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): @@ -536,14 +552,15 @@ def __call__( post_patch_width = width // p_w hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) - rotary_emb = self.rope(hidden_states) - with jax.named_scope("PatchEmbedding"): + with jax.named_scope("rotary_embedding"): + rotary_emb = self.rope(hidden_states) + with jax.named_scope("patch_embedding"): hidden_states = self.patch_embedding(hidden_states) - hidden_states = jax.lax.collapse(hidden_states, 1, -1) - - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( - timestep, encoder_hidden_states, encoder_hidden_states_image - ) + hidden_states = jax.lax.collapse(hidden_states, 1, -1) + with jax.named_scope("condition_embedder"): + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image + ) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) if encoder_hidden_states_image is not None: @@ -594,4 +611,4 @@ def layer_forward(hidden_states): hidden_states = jax.lax.collapse(hidden_states, 6, None) hidden_states = jax.lax.collapse(hidden_states, 4, 6) hidden_states = jax.lax.collapse(hidden_states, 2, 4) - return hidden_states + return hidden_states \ No newline at end of file From c31048f515b00859ebcb6f5acaead266fc2f4eff Mon Sep 17 00:00:00 2001 From: eltsai Date: Thu, 20 Nov 2025 08:59:59 +0000 Subject: [PATCH 2/4] Added flag to enable named_scope --- src/maxdiffusion/configs/base_wan_14b.yml | 4 ++ src/maxdiffusion/models/attention_flax.py | 25 ++++--- .../wan/transformers/transformer_wan.py | 65 +++++++++++++------ .../pipelines/wan/wan_pipeline.py | 1 + 4 files changed, 66 insertions(+), 29 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 78dca3be4..8e0729738 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -284,6 +284,10 @@ enable_profiler: False skip_first_n_steps_for_profiler: 5 profiler_steps: 10 +# Enable JAX named scopes for detailed profiling and debugging +# When enabled, adds named scopes around key operations in transformer and attention layers +enable_jax_named_scopes: False + # Generation parameters prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." prompt_2: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index db69c399a..02b51f9c2 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import functools import math from typing import Optional, Callable, Tuple @@ -805,6 +806,7 @@ def __init__( is_self_attention: bool = True, mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, + enable_jax_named_scopes: bool = False, ): if attention_kernel == "cudnn_flash_te": raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}") @@ -820,6 +822,7 @@ def __init__( self.key_axis_names = key_axis_names self.value_axis_names = value_axis_names self.out_axis_names = out_axis_names + self.enable_jax_named_scopes = enable_jax_named_scopes if is_self_attention: axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV) @@ -952,6 +955,10 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup return xq_out, xk_out + def conditional_named_scope(self, name: str): + """Return a JAX named scope if enabled, otherwise a null context.""" + return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() + def __call__( self, hidden_states: jax.Array, @@ -965,7 +972,7 @@ def __call__( dtype = hidden_states.dtype if encoder_hidden_states is None: encoder_hidden_states = hidden_states - + with jax.named_scope("attn_qkv_proj"): with jax.named_scope("proj_query"): query_proj = self.query(hidden_states) @@ -975,13 +982,13 @@ def __call__( value_proj = self.value(encoder_hidden_states) if self.qk_norm: - with jax.named_scope("attn_q_norm"): + with self.conditional_named_scope("attn_q_norm"): query_proj = self.norm_q(query_proj) - with jax.named_scope("attn_k_norm"): + with self.conditional_named_scope("attn_k_norm"): key_proj = self.norm_k(key_proj) - + if rotary_emb is not None: - with jax.named_scope("attn_rope"): + with self.conditional_named_scope("attn_rope"): query_proj = _unflatten_heads(query_proj, self.heads) key_proj = _unflatten_heads(key_proj, self.heads) value_proj = _unflatten_heads(value_proj, self.heads) @@ -991,14 +998,14 @@ def __call__( query_proj = checkpoint_name(query_proj, "query_proj") key_proj = checkpoint_name(key_proj, "key_proj") value_proj = checkpoint_name(value_proj, "value_proj") - - with jax.named_scope("attn_compute"): + + with self.conditional_named_scope("attn_compute"): attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) attn_output = attn_output.astype(dtype=dtype) attn_output = checkpoint_name(attn_output, "attn_output") - - with jax.named_scope("attn_out_proj"): + + with self.conditional_named_scope("attn_out_proj"): hidden_states = self.proj_attn(attn_output) hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) return hidden_states diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index f411a862b..8ff2cfefc 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -15,6 +15,7 @@ """ from typing import Tuple, Optional, Dict, Union, Any +import contextlib import math import jax import jax.numpy as jnp @@ -205,11 +206,13 @@ def __init__( dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None, + enable_jax_named_scopes: bool = False, ): if inner_dim is None: inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim + self.enable_jax_named_scopes = enable_jax_named_scopes self.act_fn = nnx.data(None) if activation_fn == "gelu-approximate": self.act_fn = ApproximateGELU( @@ -236,12 +239,16 @@ def __init__( ), ) + def conditional_named_scope(self, name: str): + """Return a JAX named scope if enabled, otherwise a null context.""" + return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() + def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array: - with jax.named_scope("mlp_up_proj_and_gelu"): + with self.conditional_named_scope("mlp_up_proj_and_gelu"): hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824) hidden_states = checkpoint_name(hidden_states, "ffn_activation") hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) - with jax.named_scope("mlp_down_proj"): + with self.conditional_named_scope("mlp_down_proj"): return self.proj_out(hidden_states) # output is (4, 75600, 5120) @@ -267,8 +274,11 @@ def __init__( attention: str = "dot_product", dropout: float = 0.0, mask_padding_tokens: bool = True, + enable_jax_named_scopes: bool = False, ): + self.enable_jax_named_scopes = enable_jax_named_scopes + # 1. Self-attention self.norm1 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) self.attn1 = FlaxWanAttention( @@ -289,6 +299,7 @@ def __init__( is_self_attention=True, mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="self_attn", + enable_jax_named_scopes=enable_jax_named_scopes, ) # 1. Cross-attention @@ -310,6 +321,7 @@ def __init__( is_self_attention=False, mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="cross_attn", + enable_jax_named_scopes=enable_jax_named_scopes, ) assert cross_attn_norm is True self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True) @@ -324,6 +336,7 @@ def __init__( weights_dtype=weights_dtype, precision=precision, dropout=dropout, + enable_jax_named_scopes=enable_jax_named_scopes, ) self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False) @@ -332,6 +345,10 @@ def __init__( jax.random.normal(key, (1, 6, dim)) / dim**0.5, ) + def conditional_named_scope(self, name: str): + """Return a JAX named scope if enabled, otherwise a null context.""" + return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() + def __call__( self, hidden_states: jax.Array, @@ -341,8 +358,8 @@ def __call__( deterministic: bool = True, rngs: nnx.Rngs = None, ): - with jax.named_scope("transformer_block"): - with jax.named_scope("adaln"): + with self.conditional_named_scope("transformer_block"): + with self.conditional_named_scope("adaln"): shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1 ) @@ -351,12 +368,12 @@ def __call__( encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None)) # 1. Self-attention - with jax.named_scope("self_attn"): - with jax.named_scope("self_attn_norm"): + with self.conditional_named_scope("self_attn"): + with self.conditional_named_scope("self_attn_norm"): norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype( hidden_states.dtype ) - with jax.named_scope("self_attn_attn"): + with self.conditional_named_scope("self_attn_attn"): attn_output = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, @@ -364,32 +381,32 @@ def __call__( deterministic=deterministic, rngs=rngs, ) - with jax.named_scope("self_attn_residual"): + with self.conditional_named_scope("self_attn_residual"): hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype) # 2. Cross-attention - with jax.named_scope("cross_attn"): - with jax.named_scope("cross_attn_norm"): + with self.conditional_named_scope("cross_attn"): + with self.conditional_named_scope("cross_attn_norm"): norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype) - with jax.named_scope("cross_attn_attn"): + with self.conditional_named_scope("cross_attn_attn"): attn_output = self.attn2( hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs, ) - with jax.named_scope("cross_attn_residual"): + with self.conditional_named_scope("cross_attn_residual"): hidden_states = hidden_states + attn_output # 3. Feed-forward - with jax.named_scope("mlp"): - with jax.named_scope("mlp_norm"): + with self.conditional_named_scope("mlp"): + with self.conditional_named_scope("mlp_norm"): norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype( hidden_states.dtype ) - with jax.named_scope("mlp_ffn"): + with self.conditional_named_scope("mlp_ffn"): ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs) - with jax.named_scope("mlp_residual"): + with self.conditional_named_scope("mlp_residual"): hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype( hidden_states.dtype ) @@ -432,11 +449,13 @@ def __init__( names_which_can_be_offloaded: list = [], mask_padding_tokens: bool = True, scan_layers: bool = True, + enable_jax_named_scopes: bool = False, ): inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels self.num_layers = num_layers self.scan_layers = scan_layers + self.enable_jax_named_scopes = enable_jax_named_scopes # 1. Patch & position embedding self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) @@ -488,6 +507,7 @@ def init_block(rngs): attention=attention, dropout=dropout, mask_padding_tokens=mask_padding_tokens, + enable_jax_named_scopes=enable_jax_named_scopes, ) self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) @@ -513,6 +533,7 @@ def init_block(rngs): weights_dtype=weights_dtype, precision=precision, attention=attention, + enable_jax_named_scopes=enable_jax_named_scopes, ) blocks.append(block) self.blocks = blocks @@ -533,6 +554,10 @@ def init_block(rngs): kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")), ) + def conditional_named_scope(self, name: str): + """Return a JAX named scope if enabled, otherwise a null context.""" + return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() + def __call__( self, hidden_states: jax.Array, @@ -552,12 +577,12 @@ def __call__( post_patch_width = width // p_w hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) - with jax.named_scope("rotary_embedding"): + with self.conditional_named_scope("rotary_embedding"): rotary_emb = self.rope(hidden_states) - with jax.named_scope("patch_embedding"): + with self.conditional_named_scope("patch_embedding"): hidden_states = self.patch_embedding(hidden_states) hidden_states = jax.lax.collapse(hidden_states, 1, -1) - with jax.named_scope("condition_embedder"): + with self.conditional_named_scope("condition_embedder"): temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) @@ -611,4 +636,4 @@ def layer_forward(hidden_states): hidden_states = jax.lax.collapse(hidden_states, 6, None) hidden_states = jax.lax.collapse(hidden_states, 4, 6) hidden_states = jax.lax.collapse(hidden_states, 2, 4) - return hidden_states \ No newline at end of file + return hidden_states diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 7ed8007b3..9068d2568 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -114,6 +114,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): wan_config["dropout"] = config.dropout wan_config["mask_padding_tokens"] = config.mask_padding_tokens wan_config["scan_layers"] = config.scan_layers + wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes # 2. eval_shape - will not use flops or create weights on device # thus not using HBM memory. From f28ef839e1faa806479a8d4a025e18db37d95f8e Mon Sep 17 00:00:00 2001 From: eltsai Date: Thu, 20 Nov 2025 20:17:00 +0000 Subject: [PATCH 3/4] Make named_scope in flash_attention respect enable_jax_named_scopes flag --- src/maxdiffusion/models/attention_flax.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 02b51f9c2..cfe3c1fc1 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -973,12 +973,12 @@ def __call__( if encoder_hidden_states is None: encoder_hidden_states = hidden_states - with jax.named_scope("attn_qkv_proj"): - with jax.named_scope("proj_query"): + with self.conditional_named_scope("attn_qkv_proj"): + with self.conditional_named_scope("proj_query"): query_proj = self.query(hidden_states) - with jax.named_scope("proj_key"): + with self.conditional_named_scope("proj_key"): key_proj = self.key(encoder_hidden_states) - with jax.named_scope("proj_value"): + with self.conditional_named_scope("proj_value"): value_proj = self.value(encoder_hidden_states) if self.qk_norm: From 496875de69db248a947b9d26d3efb3ae905b9958 Mon Sep 17 00:00:00 2001 From: eltsai Date: Fri, 21 Nov 2025 00:49:29 +0000 Subject: [PATCH 4/4] Added named scope for WanModel output --- .../models/wan/transformers/transformer_wan.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 8ff2cfefc..5d7aec101 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -625,9 +625,10 @@ def layer_forward(hidden_states): hidden_states = rematted_layer_forward(hidden_states) shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) - - hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) - hidden_states = self.proj_out(hidden_states) + with self.conditional_named_scope("output_norm"): + hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) + with self.conditional_named_scope("output_proj"): + hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1