@@ -646,9 +646,9 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
 
         for i, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
             if attn is not None:
-                jax.debug.print("MidBlock before attn {i}: {shape}", shape=x.shape)
+                jax.debug.print("MidBlock before attn {i}: {shape}", i=i, shape=x.shape)
                 x = attn(x)
-                jax.debug.print("MidBlock after attn {i}: {shape}", shape=x.shape)
+                jax.debug.print("MidBlock after attn {i}: {shape}", i=i, shape=x.shape)
             x, c = resnet(x, cache.get("resnets", [None] * len(self.resnets))[i + 1])
             new_cache["resnets"].append(c)
 
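Note on this hunk: `jax.debug.print` formats its string with `str.format` using only the args/kwargs passed to it, so a `{i}` placeholder with no matching `i=` keyword fails with `KeyError: 'i'`. A minimal sketch of the failure and the fix, assuming a toy function `f` that is purely illustrative and not part of this model:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):  # hypothetical toy function, not from this repo
    for i in range(2):  # Python loop, unrolled during tracing
        # Pre-fix form: the format string names {i} but no `i` kwarg is
        # supplied, so formatting fails with KeyError: 'i'.
        #   jax.debug.print("step {i}: {val}", val=x)
        # Post-fix form: every placeholder is bound to a keyword argument.
        jax.debug.print("step {i}: {val}", i=i, val=x)
        x = x + 1.0
    return x

f(jnp.zeros(()))  # prints step 0 and step 1 with the running value of x
```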
@@ -881,14 +881,14 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
         current_down_caches = cache.get("down_blocks", [None] * len(self.down_blocks))
 
         for i, layer in enumerate(self.down_blocks):
-            jax.debug.print(f"Encoder before down_block {i} ({type(layer).__name__}): {{shape}}", shape=x.shape)
+            jax.debug.print(f"Encoder before down_block {i} ({type(layer).__name__}): {{shape}}", i=i, shape=x.shape)
             if isinstance(layer, (WanResidualBlock, WanResample)):
                 x, c = layer(x, current_down_caches[i])
                 new_cache["down_blocks"].append(c)
             else:
                 x = layer(x)
                 new_cache["down_blocks"].append(None)
-            jax.debug.print(f"Encoder after down_block {i}: {{shape}}", shape=x.shape)
+            jax.debug.print(f"Encoder after down_block {i}: {{shape}}", i=i, shape=x.shape)
 
         jax.debug.print("Encoder before mid_block: {shape}", shape=x.shape)
         x, c = self.mid_block(x, cache.get("mid_block"))
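Note on the second hunk: those strings are f-strings, so `{i}` and `{type(layer).__name__}` are interpolated by Python at trace time and only the escaped `{{shape}}` reaches `jax.debug.print` as a runtime placeholder. The added `i=i` keyword is therefore redundant there, though harmless, since `str.format` ignores unused keyword arguments. A small sketch of the two-stage formatting, where `layer_name` and the shape are invented stand-ins:

```python
import jax
import jax.numpy as jnp

i = 0
layer_name = "WanResidualBlock"  # stands in for type(layer).__name__
x = jnp.zeros((1, 4, 32, 32, 96))  # illustrative shape, not from the model

# Stage 1 (trace time): the f-string bakes in {i} and layer_name, and the
# escaped {{shape}} becomes a literal {shape} placeholder in the result.
# Stage 2 (runtime): jax.debug.print fills {shape} from its kwargs; the
# extra i=i keyword is simply ignored by str.format.
jax.debug.print(f"Encoder before down_block {i} ({layer_name}): {{shape}}", i=i, shape=x.shape)
```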