@@ -236,10 +236,10 @@ def __init__(
         )
 
     def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
-        hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824)
+        hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
         hidden_states = checkpoint_name(hidden_states, "ffn_activation")
         hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
-        return self.proj_out(hidden_states) # output is (4, 75600, 5120)
+        return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
 
 
 class WanTransformerBlock(nnx.Module):
@@ -281,7 +281,7 @@ def __init__(
             weights_dtype=weights_dtype,
             precision=precision,
             attention_kernel=attention,
-            dropout=dropout
+            dropout=dropout,
         )
 
         # 1. Cross-attention
@@ -299,7 +299,7 @@ def __init__(
             weights_dtype=weights_dtype,
             precision=precision,
             attention_kernel=attention,
-            dropout=dropout
+            dropout=dropout,
         )
         assert cross_attn_norm is True
         self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -313,15 +313,24 @@ def __init__(
             dtype=dtype,
             weights_dtype=weights_dtype,
             precision=precision,
-            dropout=dropout
+            dropout=dropout,
         )
         self.norm3 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=False)
 
         key = rngs.params()
         self.adaln_scale_shift_table = nnx.Param(
-            jax.random.normal(key, (1, 6, dim)) / dim**0.5,)
+            jax.random.normal(key, (1, 6, dim)) / dim**0.5,
+        )
 
-    def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, temb: jax.Array, rotary_emb: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None,):
+    def __call__(
+        self,
+        hidden_states: jax.Array,
+        encoder_hidden_states: jax.Array,
+        temb: jax.Array,
+        rotary_emb: jax.Array,
+        deterministic: bool = True,
+        rngs: nnx.Rngs = None,
+    ):
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
             (self.adaln_scale_shift_table + temb), 6, axis=1
         )
@@ -331,13 +340,19 @@ def __call__(self, hidden_states: jax.Array, encoder_hidden_states: jax.Array, t
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
         attn_output = self.attn1(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, rotary_emb=rotary_emb, deterministic=deterministic, rngs=rngs
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_hidden_states,
+            rotary_emb=rotary_emb,
+            deterministic=deterministic,
+            rngs=rngs,
         )
         hidden_states = (hidden_states + attn_output * gate_msa).astype(hidden_states.dtype)
 
         # 2. Cross-attention
         norm_hidden_states = self.norm2(hidden_states)
-        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs)
+        attn_output = self.attn2(
+            hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
+        )
         hidden_states = hidden_states + attn_output
 
         # 3. Feed-forward
@@ -380,7 +395,7 @@ def __init__(
         attention: str = "dot_product",
         remat_policy: str = "None",
         names_which_can_be_saved: list = [],
-        names_which_can_be_offloaded: list = []
+        names_which_can_be_offloaded: list = [],
     ):
         inner_dim = num_attention_heads * attention_head_dim
         out_channels = out_channels or in_channels
@@ -417,7 +432,7 @@ def __init__(
 
         # 3. Transformer blocks
         @nnx.split_rngs(splits=num_layers)
-        @nnx.vmap(in_axes=0, out_axes=0, transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"} )
+        @nnx.vmap(in_axes=0, out_axes=0, transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"})
         def init_block(rngs):
             return WanTransformerBlock(
                 rngs=rngs,
@@ -496,7 +511,9 @@ def scan_fn(carry, block):
             new_carry = (hidden_states, rngs_carry)
             return new_carry, None
 
-        rematted_block_forward = self.gradient_checkpoint.apply(scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded)
+        rematted_block_forward = self.gradient_checkpoint.apply(
+            scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded
+        )
         initial_carry = (hidden_states, rngs)
         final_carry, _ = nnx.scan(
             rematted_block_forward,
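
As a side note, here is a minimal standalone sketch of the adaLN modulation pattern that WanTransformerBlock.__call__ applies in the diff above: the learned (1, 6, dim) table is offset by the time embedding, split into six shift/scale/gate chunks, and used to modulate the normalized activations before self-attention, with the remaining chunks reserved for the feed-forward branch. The function name, the norm_fn argument, and the example shapes are illustrative assumptions, not part of the patch.

import jax.numpy as jnp

def adaln_modulate(hidden_states, table, temb, norm_fn):
    # table: (1, 6, dim), temb: (batch, 6, dim); broadcasting yields six (batch, 1, dim) chunks
    shift_msa, scale_msa, gate_msa, c_shift, c_scale, c_gate = jnp.split(table + temb, 6, axis=1)
    # scale-and-shift the normalized hidden states, mirroring the self-attention branch above
    modulated = (norm_fn(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
    # gate_msa later scales the attention output on the residual path; the c_* chunks
    # play the same roles (shift, scale, gate) for the feed-forward sub-block
    return modulated, gate_msa, (c_shift, c_scale, c_gate)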