@@ -99,7 +99,7 @@ def __init__(
 
     def initialize_cache(self, batch_size, height, width, dtype):
        cache = jnp.zeros(
-            (batch_size, CACHE_T, height, width, self.conv.in_features), dtype=dtype
+            (batch_size, CACHE_T, height, width, self.conv.in_features), dtype=jnp.bfloat16
        )
 
        # OPTIMIZATION: Spatial Partitioning on Initialization
@@ -139,6 +139,8 @@ def __call__(
        current_padding = list(self._causal_padding)
 
        if cache_x is not None:
+            if cache_x.dtype != x.dtype:
+                cache_x = cache_x.astype(x.dtype)
            x_concat = jnp.concatenate([cache_x, x], axis=1)
            new_cache = x_concat[:, -CACHE_T:, ...]
 
@@ -162,6 +164,8 @@ def __call__(
            x_padded = x_input
 
        out = self.conv(x_padded)
+        new_cache = new_cache.astype(jnp.bfloat16)
+        print(f"Exiting WanCausalConv3d: {out.shape}")
        return out, new_cache
 
 
@@ -189,11 +193,13 @@ def __init__(
            self.bias = 0
 
    def __call__(self, x: jax.Array) -> jax.Array:
+        print(f"Entering WanRMS_norm: {x.shape}")
        normalized = jnp.linalg.norm(x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True)
        normalized = x / jnp.maximum(normalized, self.eps)
        normalized = normalized * self.scale * self.gamma
        if self.bias:
            return normalized + self.bias.value
+        print(f"Exiting WanRMS_norm: {normalized.shape}")
        return normalized
 
 
@@ -206,18 +212,23 @@ def __init__(self, scale_factor: Tuple[float, float], method: str = "nearest"):
206212 self .method = method
207213
208214 def __call__ (self , x : jax .Array ) -> jax .Array :
215+ print (f"Entering WanUpsample: { x .shape } " )
209216 input_dtype = x .dtype
210217 in_shape = x .shape
211218 assert len (in_shape ) == 4 , "This module only takes tensors with shape of 4."
212219 n , h , w , c = in_shape
213220 target_h = int (h * self .scale_factor [0 ])
214221 target_w = int (w * self .scale_factor [1 ])
215222 out = jax .image .resize (x .astype (jnp .float32 ), (n , target_h , target_w , c ), method = self .method )
216- return out .astype (input_dtype )
223+ out = out .astype (input_dtype )
224+ print (f"Exiting WanUpsample: { out .shape } " )
225+ return out
217226
218227
219228class Identity (nnx .Module ):
220229 def __call__ (self , x , cache = None ):
230+ print (f"Entering Identity: { x .shape } " )
231+ print (f"Exiting Identity: { x .shape } " )
221232 return x , cache
222233
223234
@@ -252,7 +263,10 @@ def __init__(
        )
 
    def __call__(self, x, cache=None):
-        return self.conv(x), cache
+        print(f"Entering ZeroPaddedConv2D: {x.shape}")
+        out = self.conv(x)
+        print(f"Exiting ZeroPaddedConv2D: {out.shape}")
+        return out, cache
 
 
class WanResample(nnx.Module):
@@ -363,46 +377,59 @@ def initialize_cache(self, batch_size, height, width, dtype):
    def __call__(
        self, x: jax.Array, cache: Dict[str, Any] = None
    ) -> Tuple[jax.Array, Dict[str, Any]]:
+        print(f"Entering WanResample ({self.mode}): {x.shape}")
        if cache is None:
            cache = {}
        new_cache = {}
 
        if self.mode == "upsample2d":
            b, t, h, w, c = x.shape
            x = x.reshape(b * t, h, w, c)
+            print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
            x = self.resample(x)
+            print(f"WanResample ({self.mode}) after resample: {x.shape}")
            h_new, w_new, c_new = x.shape[1:]
            x = x.reshape(b, t, h_new, w_new, c_new)
 
        elif self.mode == "upsample3d":
            x, tc_cache = self.time_conv(x, cache.get("time_conv"))
            new_cache["time_conv"] = tc_cache
+            print(f"WanResample ({self.mode}) after time_conv: {x.shape}")
 
            b, t, h, w, c = x.shape
            x = x.reshape(b, t, h, w, 2, c // 2)
            x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
            x = x.reshape(b, t * 2, h, w, c // 2)
+            print(f"WanResample ({self.mode}) after time dim expand: {x.shape}")
+
 
            b, t, h, w, c = x.shape
            x = x.reshape(b * t, h, w, c)
+            print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
            x = self.resample(x)
+            print(f"WanResample ({self.mode}) after resample: {x.shape}")
            h_new, w_new, c_new = x.shape[1:]
            x = x.reshape(b, t, h_new, w_new, c_new)
 
        elif self.mode == "downsample2d":
            b, t, h, w, c = x.shape
            x = x.reshape(b * t, h, w, c)
+            print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
            x, _ = self.resample(x, None)
+            print(f"WanResample ({self.mode}) after resample: {x.shape}")
            h_new, w_new, c_new = x.shape[1:]
            x = x.reshape(b, t, h_new, w_new, c_new)
 
        elif self.mode == "downsample3d":
            x, tc_cache = self.time_conv(x, cache.get("time_conv"))
            new_cache["time_conv"] = tc_cache
+            print(f"WanResample ({self.mode}) after time_conv: {x.shape}")
 
            b, t, h, w, c = x.shape
            x = x.reshape(b * t, h, w, c)
+            print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
            x, _ = self.resample(x, None)
+            print(f"WanResample ({self.mode}) after resample: {x.shape}")
            h_new, w_new, c_new = x.shape[1:]
            x = x.reshape(b, t, h_new, w_new, c_new)
 
@@ -413,6 +440,7 @@ def __call__(
        else:
            x = self.resample(x)
 
+        print(f"Exiting WanResample ({self.mode}): {x.shape}")
        return x, new_cache
 
 
@@ -486,26 +514,33 @@ def initialize_cache(self, batch_size, height, width, dtype):
        return cache
 
    def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
+        print(f"Entering WanResidualBlock (in={self.conv1.conv.in_features}, out={self.conv1.conv.out_features}): {x.shape}")
        if cache is None:
            cache = {}
        new_cache = {}
 
        h, sc_cache = self.conv_shortcut(x, cache.get("shortcut"))
        new_cache["shortcut"] = sc_cache
+        print(f"WanResidualBlock after shortcut: {h.shape}")
 
        x = self.norm1(x)
        x = self.nonlinearity(x)
+        print(f"WanResidualBlock after norm1/nl: {x.shape}")
 
        x, c1 = self.conv1(x, cache.get("conv1"))
        new_cache["conv1"] = c1
+        print(f"WanResidualBlock after conv1: {x.shape}")
 
        x = self.norm2(x)
        x = self.nonlinearity(x)
+        print(f"WanResidualBlock after norm2/nl: {x.shape}")
 
        x, c2 = self.conv2(x, cache.get("conv2"))
        new_cache["conv2"] = c2
+        print(f"WanResidualBlock after conv2: {x.shape}")
 
        x = x + h
+        print(f"Exiting WanResidualBlock: {x.shape}")
        return x, new_cache
 
 
@@ -544,15 +579,17 @@ def __init__(
        )
 
    def __call__(self, x: jax.Array):
-        jax.debug.print("AttentionBlock input shape: {shape}", shape=x.shape)
+        print(f"Entering WanAttentionBlock: {x.shape}")
        identity = x
        batch_size, time, height, width, channels = x.shape
 
        x = x.reshape(batch_size * time, height, width, channels)
+        print(f"WanAttentionBlock reshaped for norm: {x.shape}")
        x = self.norm(x)
 
        qkv = self.to_qkv(x)  # Output: (N*D, H, W, C * 3)
        # qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+        print(f"WanAttentionBlock qkv shape: {qkv.shape}")
        qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3)
        qkv = jnp.transpose(qkv, (0, 1, 3, 2))
        q, k, v = jnp.split(qkv, 3, axis=-2)
@@ -566,8 +603,9 @@ def __call__(self, x: jax.Array):
        x = self.proj(x)
        # Reshape back
        x = x.reshape(batch_size, time, height, width, channels)
-
-        return x + identity
+        out = x + identity
+        print(f"Exiting WanAttentionBlock: {out.shape}")
+        return out
 
 
class WanMidBlock(nnx.Module):
@@ -635,23 +673,26 @@ def initialize_cache(self, batch_size, height, width, dtype):
        return cache
 
    def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
+        print(f"Entering WanMidBlock: {x.shape}")
        if cache is None:
            cache = {}
        new_cache = {"resnets": []}
-        jax.debug.print("MidBlock input shape: {shape}", shape=x.shape)
 
        x, c = self.resnets[0](x, cache.get("resnets", [None])[0])
        new_cache["resnets"].append(c)
-        jax.debug.print("MidBlock after resnets[0] shape: {shape}", shape=x.shape)
+        print(f"WanMidBlock after resnets[0]: {x.shape}")
 
        for i, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
            if attn is not None:
-                jax.debug.print("MidBlock before attn {i}: {shape}", i=i, shape=x.shape)
+                print(f"WanMidBlock before attn {i}: {x.shape}")
                x = attn(x)
-                jax.debug.print("MidBlock after attn {i}: {shape}", i=i, shape=x.shape)
+                print(f"WanMidBlock after attn {i}: {x.shape}")
+            print(f"WanMidBlock before resnets[{i + 1}]: {x.shape}")
            x, c = resnet(x, cache.get("resnets", [None] * len(self.resnets))[i + 1])
            new_cache["resnets"].append(c)
+            print(f"WanMidBlock after resnets[{i + 1}]: {x.shape}")
 
+        print(f"Exiting WanMidBlock: {x.shape}")
        return x, new_cache
 
 
@@ -869,38 +910,41 @@ def init_cache(self, batch_size, height, width, dtype):
        return cache
 
    def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
+        print(f"Entering WanEncoder3d: {x.shape}")
        if cache is None:
            cache = {}
        new_cache = {}
-        jax.debug.print("Encoder input shape: {shape}", shape=x.shape)
+
        x, c = self.conv_in(x, cache.get("conv_in"))
        new_cache["conv_in"] = c
-        jax.debug.print("Encoder after conv_in shape: {shape}", shape=x.shape)
+        print(f"WanEncoder3d after conv_in: {x.shape}")
 
        new_cache["down_blocks"] = []
        current_down_caches = cache.get("down_blocks", [None] * len(self.down_blocks))
 
        for i, layer in enumerate(self.down_blocks):
-            jax.debug.print("Encoder before down_block {i} (" + type(layer).__name__ + "): {shape}", i=i, shape=x.shape)
+            print(f"WanEncoder3d before down_block {i} ({type(layer).__name__}): {x.shape}")
            if isinstance(layer, (WanResidualBlock, WanResample)):
                x, c = layer(x, current_down_caches[i])
                new_cache["down_blocks"].append(c)
            else:
                x = layer(x)
                new_cache["down_blocks"].append(None)
-            jax.debug.print("Encoder after down_block {i} (" + type(layer).__name__ + "): {shape}", i=i, shape=x.shape)
+            print(f"WanEncoder3d after down_block {i}: {x.shape}")
 
 
-        jax.debug.print("Encoder before mid_block: {shape}", shape=x.shape)
        x, c = self.mid_block(x, cache.get("mid_block"))
        new_cache["mid_block"] = c
-        jax.debug.print("Encoder after mid_block: {shape}", shape=x.shape)
+        print(f"WanEncoder3d after mid_block: {x.shape}")
 
        x = self.norm_out(x)
+        print(f"WanEncoder3d after norm_out: {x.shape}")
        x = self.nonlinearity(x)
+        print(f"WanEncoder3d after nonlinearity: {x.shape}")
 
        x, c = self.conv_out(x, cache.get("conv_out"))
        new_cache["conv_out"] = c
+        print(f"Exiting WanEncoder3d: {x.shape}")
 
        return x, new_cache
 
@@ -1203,7 +1247,7 @@ def __init__(
            precision=precision,
        )
 
-    @nnx.jit
+    # @nnx.jit
    def encode(
        self, x: jax.Array, return_dict: bool = True
    ) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
@@ -1270,7 +1314,7 @@ def scan_fn_chunk(carry, input_slice):
            return (posterior,)
        return FlaxAutoencoderKLOutput(latent_dist=posterior)
 
-    @nnx.jit
+    # @nnx.jit
    def decode(
        self, z: jax.Array, return_dict: bool = True
    ) -> Union[FlaxDecoderOutput, jax.Array]: