@@ -544,7 +544,7 @@ def __init__(
544544 )
545545
546546 def __call__ (self , x : jax .Array ):
547-
547+ jax . debug . print ( "AttentionBlock input shape: {shape}" , shape = x . shape )
548548 identity = x
549549 batch_size , time , height , width , channels = x .shape
550550
@@ -638,13 +638,17 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
638638 if cache is None :
639639 cache = {}
640640 new_cache = {"resnets" : []}
641+ jax .debug .print ("MidBlock input shape: {shape}" , shape = x .shape )
641642
642643 x , c = self .resnets [0 ](x , cache .get ("resnets" , [None ])[0 ])
643644 new_cache ["resnets" ].append (c )
645+ jax .debug .print ("MidBlock after resnets[0] shape: {shape}" , shape = x .shape )
644646
645647 for i , (attn , resnet ) in enumerate (zip (self .attentions , self .resnets [1 :])):
646648 if attn is not None :
649+ jax .debug .print ("MidBlock before attn {i}: {shape}" , i = i , shape = x .shape )
647650 x = attn (x )
651+ jax .debug .print ("MidBlock after attn {i}: {shape}" , i = i , shape = x .shape )
648652 x , c = resnet (x , cache .get ("resnets" , [None ] * len (self .resnets ))[i + 1 ])
649653 new_cache ["resnets" ].append (c )
650654
@@ -868,23 +872,28 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
868872 if cache is None :
869873 cache = {}
870874 new_cache = {}
871-
875+ jax . debug . print ( "Encoder input shape: {shape}" , shape = x . shape )
872876 x , c = self .conv_in (x , cache .get ("conv_in" ))
873877 new_cache ["conv_in" ] = c
878+ jax .debug .print ("Encoder after conv_in shape: {shape}" , shape = x .shape )
874879
875880 new_cache ["down_blocks" ] = []
876881 current_down_caches = cache .get ("down_blocks" , [None ] * len (self .down_blocks ))
877882
878883 for i , layer in enumerate (self .down_blocks ):
884+ jax .debug .print (f"Encoder before down_block { i } ({ type (layer ).__name__ } ): {{shape}}" , shape = x .shape )
879885 if isinstance (layer , (WanResidualBlock , WanResample )):
880886 x , c = layer (x , current_down_caches [i ])
881887 new_cache ["down_blocks" ].append (c )
882888 else :
883889 x = layer (x )
884890 new_cache ["down_blocks" ].append (None )
891+ jax .debug .print (f"Encoder after down_block { i } : {{shape}}" , shape = x .shape )
885892
893+ jax .debug .print ("Encoder before mid_block: {shape}" , shape = x .shape )
886894 x , c = self .mid_block (x , cache .get ("mid_block" ))
887895 new_cache ["mid_block" ] = c
896+ jax .debug .print ("Encoder after mid_block: {shape}" , shape = x .shape )
888897
889898 x = self .norm_out (x )
890899 x = self .nonlinearity (x )
@@ -1193,7 +1202,7 @@ def __init__(
11931202 precision = precision ,
11941203 )
11951204
1196- # @nnx.jit
1205+ @nnx .jit
11971206 def encode (
11981207 self , x : jax .Array , return_dict : bool = True
11991208 ) -> Union [FlaxAutoencoderKLOutput , Tuple [FlaxDiagonalGaussianDistribution ]]:
0 commit comments