@@ -571,39 +571,43 @@ def __init__(
             precision=precision,
         )
 
-    def __call__(self, x: jax.Array):
-        identity = x
-        batch_size, time, height, width, channels = x.shape
-
-        # Reshape to process all frames together
-        x = x.reshape(batch_size * time, height, width, channels)
-        x = self.norm(x)
-
-        qkv = self.to_qkv(x)  # (B*T, H, W, C*3)
-        # Flatten spatial dimensions for attention
-        qkv = qkv.reshape(batch_size * time, -1, channels * 3)  # (B*T, H*W, C*3)
-        qkv = jnp.transpose(qkv, (0, 2, 1))  # (B*T, C*3, H*W)
-
-        q, k, v = jnp.split(qkv, 3, axis=1)  # Each: (B*T, C, H*W)
-        q = jnp.transpose(q, (0, 2, 1))  # (B*T, H*W, C)
-        k = jnp.transpose(k, (0, 2, 1))  # (B*T, H*W, C)
-        v = jnp.transpose(v, (0, 2, 1))  # (B*T, H*W, C)
-
-        # Add head dimension for dot_product_attention
-        q = jnp.expand_dims(q, 1)  # (B*T, 1, H*W, C)
-        k = jnp.expand_dims(k, 1)  # (B*T, 1, H*W, C)
-        v = jnp.expand_dims(v, 1)  # (B*T, 1, H*W, C)
-
-        x = jax.nn.dot_product_attention(q, k, v)  # (B*T, 1, H*W, C)
-        x = jnp.squeeze(x, 1)  # (B*T, H*W, C)
-
-        # Reshape back to spatial dimensions
-        x = x.reshape(batch_size * time, height, width, channels)
-        x = self.proj(x)
-
-        # Reshape back to original shape
-        x = x.reshape(batch_size, time, height, width, channels)
-        return x + identity
+    def __call__(self, x: jax.Array):
+        identity = x
+        batch_size, time, height, width, channels = x.shape
+
+        # Reshape to process all frames together
+        x = x.reshape(batch_size * time, height, width, channels)
+        x = self.norm(x)
+
+        qkv = self.to_qkv(x)  # (B*T, H, W, C*3)
+
+        # Get actual shape after to_qkv to avoid using stale variables
+        bt, h, w, c3 = qkv.shape
+
+        # Flatten spatial dimensions for attention
+        qkv = qkv.reshape(bt, h * w, c3)  # (B*T, H*W, C*3)
+        qkv = jnp.transpose(qkv, (0, 2, 1))  # (B*T, C*3, H*W)
+
+        q, k, v = jnp.split(qkv, 3, axis=1)  # Each: (B*T, C, H*W)
+        q = jnp.transpose(q, (0, 2, 1))  # (B*T, H*W, C)
+        k = jnp.transpose(k, (0, 2, 1))  # (B*T, H*W, C)
+        v = jnp.transpose(v, (0, 2, 1))  # (B*T, H*W, C)
+
+        # Add a singleton head axis; jax.nn.dot_product_attention expects
+        # (batch, seq_len, num_heads, head_dim), so heads go on axis 2
+        q = jnp.expand_dims(q, 2)  # (B*T, H*W, 1, C)
+        k = jnp.expand_dims(k, 2)  # (B*T, H*W, 1, C)
+        v = jnp.expand_dims(v, 2)  # (B*T, H*W, 1, C)
+
+        x = jax.nn.dot_product_attention(q, k, v)  # (B*T, H*W, 1, C)
+        x = jnp.squeeze(x, 2)  # (B*T, H*W, C)
+
+        # Reshape back to spatial dimensions
+        x = x.reshape(bt, h, w, channels)
+        x = self.proj(x)
+
+        # Reshape back to original shape
+        x = x.reshape(batch_size, time, height, width, channels)
+        return x + identity
 
 
 class WanMidBlock(nnx.Module):
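
For reference, here is a minimal standalone sketch of the same attention path on dummy data (not part of the commit: the spatial_self_attention helper and the shapes are illustrative; jax.nn.dot_product_attention, available in recent JAX releases, is the only real API used, and it documents a (batch, seq_len, num_heads, head_dim) layout):

import jax
import jax.numpy as jnp

def spatial_self_attention(qkv: jax.Array) -> jax.Array:
    # Hypothetical helper mirroring the block's math; qkv is (B*T, H, W, C*3)
    # as produced by a 1x1 to_qkv projection.
    bt, h, w, c3 = qkv.shape
    c = c3 // 3
    # Splitting on the last axis yields the same q/k/v partition as the
    # transpose/split/transpose sequence in the block above.
    q, k, v = jnp.split(qkv.reshape(bt, h * w, c3), 3, axis=-1)
    # dot_product_attention expects (batch, seq_len, num_heads, head_dim),
    # so a single head of width C means a singleton axis 2.
    q, k, v = (jnp.expand_dims(a, 2) for a in (q, k, v))
    out = jax.nn.dot_product_attention(q, k, v)  # (B*T, H*W, 1, C)
    return jnp.squeeze(out, 2).reshape(bt, h, w, c)

x = jnp.ones((2 * 4, 8, 8, 32 * 3))  # B=2, T=4, 8x8 frames, C=32
print(spatial_self_attention(x).shape)  # (8, 8, 8, 32)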