@@ -142,8 +142,8 @@ def __call__(
142142 ):
143143 timestep = self .timesteps_proj (timestep )
144144 temb = self .time_embedder (timestep )
145-
146- timestep_proj = self .time_proj (self .act_fn (temb ))
145+ with jax . named_scope ( "time_proj" ):
146+ timestep_proj = self .time_proj (self .act_fn (temb ))
147147
148148 encoder_hidden_states = self .text_embedder (encoder_hidden_states )
149149 if encoder_hidden_states_image is not None :
@@ -186,7 +186,8 @@ def __init__(
186186 )
187187
188188 def __call__ (self , x : jax .Array ) -> jax .Array :
189- x = self .proj (x )
189+ with jax .named_scope ("gelu" ):
190+ x = self .proj (x )
190191 return nnx .gelu (x )
191192
192193
@@ -244,12 +245,11 @@ def conditional_named_scope(self, name: str):
244245 return jax .named_scope (name ) if self .enable_jax_named_scopes else contextlib .nullcontext ()
245246
246247 def __call__ (self , hidden_states : jax .Array , deterministic : bool = True , rngs : nnx .Rngs = None ) -> jax .Array :
247- with self .conditional_named_scope ("mlp_up_proj_and_gelu" ):
248248 hidden_states = self .act_fn (hidden_states ) # Output is (4, 75600, 13824)
249249 hidden_states = checkpoint_name (hidden_states , "ffn_activation" )
250250 hidden_states = self .drop_out (hidden_states , deterministic = deterministic , rngs = rngs )
251- with self . conditional_named_scope ( "mlp_down_proj " ):
252- return self .proj_out (hidden_states ) # output is (4, 75600, 5120)
251+ with jax . named_scope ( "proj_out " ):
252+ return self .proj_out (hidden_states ) # output is (4, 75600, 5120)
253253
254254
255255class WanTransformerBlock (nnx .Module ):
@@ -354,48 +354,42 @@ def __call__(
354354 rngs : nnx .Rngs = None ,
355355 ):
356356 with self .conditional_named_scope ("transformer_block" ):
357- with self .conditional_named_scope ("adaln" ):
358- shift_msa , scale_msa , gate_msa , c_shift_msa , c_scale_msa , c_gate_msa = jnp .split (
359- (self .adaln_scale_shift_table + temb .astype (jnp .float32 )), 6 , axis = 1
360- )
357+ shift_msa , scale_msa , gate_msa , c_shift_msa , c_scale_msa , c_gate_msa = jnp .split (
358+ (self .adaln_scale_shift_table + temb .astype (jnp .float32 )), 6 , axis = 1
359+ )
361360 hidden_states = jax .lax .with_sharding_constraint (hidden_states , PartitionSpec ("data" , "fsdp" , "tensor" ))
362361 hidden_states = checkpoint_name (hidden_states , "hidden_states" )
363362 encoder_hidden_states = jax .lax .with_sharding_constraint (encoder_hidden_states , PartitionSpec ("data" , "fsdp" , None ))
364363
365364 # 1. Self-attention
366- with self .conditional_named_scope ("self_attn" ):
367- with self .conditional_named_scope ("self_attn_norm" ):
365+ with jax .named_scope ("attn1" ):
368366 norm_hidden_states = (self .norm1 (hidden_states .astype (jnp .float32 )) * (1 + scale_msa ) + shift_msa ).astype (
369367 hidden_states .dtype
370368 )
371- with self .conditional_named_scope ("self_attn_attn" ):
372369 attn_output = self .attn1 (
373370 hidden_states = norm_hidden_states ,
374371 encoder_hidden_states = norm_hidden_states ,
375372 rotary_emb = rotary_emb ,
376373 deterministic = deterministic ,
377374 rngs = rngs ,
378375 )
379- with self .conditional_named_scope ("self_attn_residual" ):
380376 hidden_states = (hidden_states .astype (jnp .float32 ) + attn_output * gate_msa ).astype (hidden_states .dtype )
381377
382378 # 2. Cross-attention
383- norm_hidden_states = self .norm2 (hidden_states .astype (jnp .float32 )).astype (hidden_states .dtype )
384- attn_output = self .attn2 (
385- hidden_states = norm_hidden_states , encoder_hidden_states = encoder_hidden_states , deterministic = deterministic , rngs = rngs
386- )
387- hidden_states = hidden_states + attn_output
379+ with jax .named_scope ('attn2' ):
380+ norm_hidden_states = self .norm2 (hidden_states .astype (jnp .float32 )).astype (hidden_states .dtype )
381+ attn_output = self .attn2 (
382+ hidden_states = norm_hidden_states , encoder_hidden_states = encoder_hidden_states , deterministic = deterministic , rngs = rngs
383+ )
384+ hidden_states = hidden_states + attn_output
388385
389386 # 3. Feed-forward
390- with self .conditional_named_scope ("mlp" ):
391- with self .conditional_named_scope ("mlp_norm" ):
387+ with jax .named_scope ("ffn" ):
392388 norm_hidden_states = (self .norm3 (hidden_states .astype (jnp .float32 )) * (1 + c_scale_msa ) + c_shift_msa ).astype (
393389 hidden_states .dtype
394390 )
395- with self .conditional_named_scope ("mlp_ffn" ):
396- ff_output = self .ffn (norm_hidden_states , deterministic = deterministic , rngs = rngs )
397- with self .conditional_named_scope ("mlp_residual" ):
398- hidden_states = (hidden_states .astype (jnp .float32 ) + ff_output .astype (jnp .float32 ) * c_gate_msa ).astype (
391+ ff_output = self .ffn (norm_hidden_states , deterministic = deterministic , rngs = rngs )
392+ hidden_states = (hidden_states .astype (jnp .float32 ) + ff_output .astype (jnp .float32 ) * c_gate_msa ).astype (
399393 hidden_states .dtype
400394 )
401395 return hidden_states
@@ -543,6 +537,7 @@ def conditional_named_scope(self, name: str):
543537 """Return a JAX named scope if enabled, otherwise a null context."""
544538 return jax .named_scope (name ) if self .enable_jax_named_scopes else contextlib .nullcontext ()
545539
540+ @jax .named_scope ('WanModel' )
546541 def __call__ (
547542 self ,
548543 hidden_states : jax .Array ,
@@ -609,9 +604,8 @@ def layer_forward(hidden_states):
609604 hidden_states = rematted_layer_forward (hidden_states )
610605
611606 shift , scale = jnp .split (self .scale_shift_table + jnp .expand_dims (temb , axis = 1 ), 2 , axis = 1 )
612- with self .conditional_named_scope ("output_norm" ):
613- hidden_states = (self .norm_out (hidden_states .astype (jnp .float32 )) * (1 + scale ) + shift ).astype (hidden_states .dtype )
614- with self .conditional_named_scope ("output_proj" ):
607+ hidden_states = (self .norm_out (hidden_states .astype (jnp .float32 )) * (1 + scale ) + shift ).astype (hidden_states .dtype )
608+ with jax .named_scope ("proj_out" ):
615609 hidden_states = self .proj_out (hidden_states )
616610
617611 hidden_states = hidden_states .reshape (
0 commit comments