@@ -391,7 +391,7 @@ def __call__(
391391 new_cache = {}
392392
393393 if self .mode == "upsample2d" :
394- b , t , h , w , c = x .shape
394+ b , t , h , w , c = x .shape
395395 x = x .reshape (b * t , h , w , c )
396396 x = self .resample (x )
397397 h_new , w_new , c_new = x .shape [1 :]
@@ -403,14 +403,14 @@ def __call__(
403403
404404 b , t , h , w , c = x .shape
405405 x = x .reshape (b , t , h , w , 2 , c // 2 )
406- x = jnp .stack ([x [:, :, :, :, 0 , :], x [:, :, :, :, 1 , :]], axis = 1 )
406+ x = jnp .stack ([x [:, :, :, :, 0 , :], x [:, :, :, :, 1 , :]], axis = 1 )
407407 x = x .reshape (b , t * 2 , h , w , c // 2 )
408408
409409 b , t , h , w , c = x .shape
410- x = x .reshape (b * t , h , w , c )
411- x = self .resample (x )
412- h_new , w_new , c_new = x .shape [1 :]
413- x = x .reshape (b , t , h_new , w_new , c_new )
410+ x = x .reshape (b * t , h , w , c )
411+ x = self .resample (x )
412+ h_new , w_new , c_new = x .shape [1 :]
413+ x = x .reshape (b , t , h_new , w_new , c_new )
414414
415415 elif self .mode == "downsample2d" :
416416 b , t , h , w , c = x .shape
@@ -429,7 +429,7 @@ def __call__(
429429 x , tc_cache = self .time_conv (x , cache .get ("time_conv" ))
430430 new_cache ["time_conv" ] = tc_cache
431431
432- else :
432+ else :
433433 if hasattr (self , "resample" ):
434434 if isinstance (self .resample , Identity ):
435435 x , _ = self .resample (x , None )
@@ -485,15 +485,15 @@ def __init__(
485485 self .conv_shortcut = Identity ()
486486 if in_dim != out_dim :
487487 self .conv_shortcut = WanCausalConv3d (
488- rngs = rngs ,
489- in_channels = in_dim ,
490- out_channels = out_dim ,
491- kernel_size = 1 ,
492- mesh = mesh ,
493- dtype = dtype ,
494- weights_dtype = weights_dtype ,
495- precision = precision ,
496- )
488+ rngs = rngs ,
489+ in_channels = in_dim ,
490+ out_channels = out_dim ,
491+ kernel_size = 1 ,
492+ mesh = mesh ,
493+ dtype = dtype ,
494+ weights_dtype = weights_dtype ,
495+ precision = precision ,
496+ )
497497
498498 def initialize_cache (self , batch_size , height , width , dtype ):
499499 """Initialize cache for all convolutions."""
@@ -572,42 +572,42 @@ def __init__(
572572 )
573573
574574 def __call__ (self , x : jax .Array ):
575- identity = x
576- batch_size , time , height , width , channels = x .shape
577-
578- # Reshape to process all frames together
579- x = x .reshape (batch_size * time , height , width , channels )
580- x = self .norm (x )
581-
582- qkv = self .to_qkv (x ) # (B*T, H, W, C*3)
583-
584- # Get actual shape after to_qkv to avoid using stale variables
585- bt , h , w , c3 = qkv .shape
586-
587- # Flatten spatial dimensions for attention
588- qkv = qkv .reshape (bt , h * w , c3 ) # (B*T, H*W, C*3)
589- qkv = jnp .transpose (qkv , (0 , 2 , 1 )) # (B*T, C*3, H*W)
590-
591- q , k , v = jnp .split (qkv , 3 , axis = 1 ) # Each: (B*T, C, H*W)
592- q = jnp .transpose (q , (0 , 2 , 1 )) # (B*T, H*W, C)
593- k = jnp .transpose (k , (0 , 2 , 1 )) # (B*T, H*W, C)
594- v = jnp .transpose (v , (0 , 2 , 1 )) # (B*T, H*W, C)
595-
596- # Add head dimension for dot_product_attention
597- q = jnp .expand_dims (q , 1 ) # (B*T, 1, H*W, C)
598- k = jnp .expand_dims (k , 1 ) # (B*T, 1, H*W, C)
599- v = jnp .expand_dims (v , 1 ) # (B*T, 1, H*W, C)
600-
601- x = jax .nn .dot_product_attention (q , k , v ) # (B*T, 1, H*W, C)
602- x = jnp .squeeze (x , 1 ) # (B*T, H*W, C)
603-
604- # Reshape back to spatial dimensions
605- x = x .reshape (bt , h , w , channels )
606- x = self .proj (x )
607-
608- # Reshape back to original shape
609- x = x .reshape (batch_size , time , height , width , channels )
610- return x + identity
575+ identity = x
576+ batch_size , time , height , width , channels = x .shape
577+
578+ # Reshape to process all frames together
579+ x = x .reshape (batch_size * time , height , width , channels )
580+ x = self .norm (x )
581+
582+ qkv = self .to_qkv (x ) # (B*T, H, W, C*3)
583+
584+ # Get actual shape after to_qkv to avoid using stale variables
585+ bt , h , w , c3 = qkv .shape
586+
587+ # Flatten spatial dimensions for attention
588+ qkv = qkv .reshape (bt , h * w , c3 ) # (B*T, H*W, C*3)
589+ qkv = jnp .transpose (qkv , (0 , 2 , 1 )) # (B*T, C*3, H*W)
590+
591+ q , k , v = jnp .split (qkv , 3 , axis = 1 ) # Each: (B*T, C, H*W)
592+ q = jnp .transpose (q , (0 , 2 , 1 )) # (B*T, H*W, C)
593+ k = jnp .transpose (k , (0 , 2 , 1 )) # (B*T, H*W, C)
594+ v = jnp .transpose (v , (0 , 2 , 1 )) # (B*T, H*W, C)
595+
596+ # Add head dimension for dot_product_attention
597+ q = jnp .expand_dims (q , 1 ) # (B*T, 1, H*W, C)
598+ k = jnp .expand_dims (k , 1 ) # (B*T, 1, H*W, C)
599+ v = jnp .expand_dims (v , 1 ) # (B*T, 1, H*W, C)
600+
601+ x = jax .nn .dot_product_attention (q , k , v ) # (B*T, 1, H*W, C)
602+ x = jnp .squeeze (x , 1 ) # (B*T, H*W, C)
603+
604+ # Reshape back to spatial dimensions
605+ x = x .reshape (bt , h , w , channels )
606+ x = self .proj (x )
607+
608+ # Reshape back to original shape
609+ x = x .reshape (batch_size , time , height , width , channels )
610+ return x + identity
611611
612612
613613class WanMidBlock (nnx .Module ):
@@ -626,18 +626,18 @@ def __init__(
626626 self .dim = dim
627627 self .resnets = nnx .List (
628628 [
629- WanResidualBlock (
630- in_dim = dim ,
631- out_dim = dim ,
632- rngs = rngs ,
633- dropout = dropout ,
634- non_linearity = non_linearity ,
635- mesh = mesh ,
636- dtype = dtype ,
637- weights_dtype = weights_dtype ,
638- precision = precision ,
639- )
640- ]
629+ WanResidualBlock (
630+ in_dim = dim ,
631+ out_dim = dim ,
632+ rngs = rngs ,
633+ dropout = dropout ,
634+ non_linearity = non_linearity ,
635+ mesh = mesh ,
636+ dtype = dtype ,
637+ weights_dtype = weights_dtype ,
638+ precision = precision ,
639+ )
640+ ]
641641 )
642642 self .attentions = nnx .List ([])
643643 for _ in range (num_layers ):
@@ -991,18 +991,18 @@ def __init__(
991991 upsample_mode = "upsample3d" if temperal_upsample [i ] else "upsample2d"
992992 self .up_blocks .append (
993993 WanUpBlock (
994- in_dim = in_dim ,
995- out_dim = out_dim ,
996- num_res_blocks = num_res_blocks ,
997- dropout = dropout ,
998- upsample_mode = upsample_mode ,
999- non_linearity = non_linearity ,
1000- rngs = rngs ,
1001- mesh = mesh ,
1002- dtype = dtype ,
1003- weights_dtype = weights_dtype ,
1004- precision = precision ,
1005- )
994+ in_dim = in_dim ,
995+ out_dim = out_dim ,
996+ num_res_blocks = num_res_blocks ,
997+ dropout = dropout ,
998+ upsample_mode = upsample_mode ,
999+ non_linearity = non_linearity ,
1000+ rngs = rngs ,
1001+ mesh = mesh ,
1002+ dtype = dtype ,
1003+ weights_dtype = weights_dtype ,
1004+ precision = precision ,
1005+ )
10061006 )
10071007
10081008 self .norm_out = WanRMS_norm (
0 commit comments