@@ -166,7 +166,6 @@ def __call__(

         out = self.conv(x_padded)
         new_cache = new_cache.astype(jnp.bfloat16)
-        print(f"Exiting WanCausalConv3d: {out.shape}")
         return out, new_cache


@@ -194,13 +193,11 @@ def __init__(
         self.bias = 0

     def __call__(self, x: jax.Array) -> jax.Array:
-        print(f"Entering WanRMS_norm: {x.shape}")
         normalized = jnp.linalg.norm(x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True)
         normalized = x / jnp.maximum(normalized, self.eps)
         normalized = normalized * self.scale * self.gamma
         if self.bias:
             return normalized + self.bias.value
-        print(f"Exiting WanRMS_norm: {normalized.shape}")
         return normalized


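Aside on the math the deleted prints were tracing: assuming `self.scale` is `dim ** 0.5` (as in the reference Wan implementation; not shown in this hunk), dividing by the per-channel L2 norm and multiplying by `sqrt(dim)` is a standard RMS normalization. A minimal standalone sketch, with `rms_norm` as a hypothetical helper:

```python
import jax.numpy as jnp

def rms_norm(x, gamma, eps=1e-12, channel_first=False):
    # Illustration only; mirrors WanRMS_norm's math under the assumption
    # that scale == sqrt(dim). eps default is hypothetical.
    axis = 1 if channel_first else -1
    dim = x.shape[axis]
    l2 = jnp.linalg.norm(x, ord=2, axis=axis, keepdims=True)
    # x / max(||x||, eps) * sqrt(dim) == x / RMS(x) whenever ||x|| > eps
    return x / jnp.maximum(l2, eps) * (dim ** 0.5) * gamma
```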
@@ -213,7 +210,6 @@ def __init__(self, scale_factor: Tuple[float, float], method: str = "nearest"):
         self.method = method

     def __call__(self, x: jax.Array) -> jax.Array:
-        print(f"Entering WanUpsample: {x.shape}")
         input_dtype = x.dtype
         in_shape = x.shape
         assert len(in_shape) == 4, "This module only takes tensors with shape of 4."
@@ -222,14 +218,11 @@ def __call__(self, x: jax.Array) -> jax.Array:
         target_w = int(w * self.scale_factor[1])
         out = jax.image.resize(x.astype(jnp.float32), (n, target_h, target_w, c), method=self.method)
         out = out.astype(input_dtype)
-        print(f"Exiting WanUpsample: {out.shape}")
         return out


 class Identity(nnx.Module):
     def __call__(self, x, cache=None):
-        print(f"Entering Identity: {x.shape}")
-        print(f"Exiting Identity: {x.shape}")
         return x, cache

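Shape/dtype sketch of the resize round trip above, with toy sizes: the module upcasts to float32 for `jax.image.resize` and casts the result back to the input dtype.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1, 8, 8, 16), dtype=jnp.bfloat16)        # (N, H, W, C), toy sizes
out = jax.image.resize(x.astype(jnp.float32),           # compute in float32
                       (1, 16, 16, 16), method="nearest")
out = out.astype(x.dtype)                                # cast back to bfloat16
```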
@@ -264,9 +257,7 @@ def __init__(
         )

     def __call__(self, x, cache=None):
-        print(f"Entering ZeroPaddedConv2D: {x.shape}")
         out = self.conv(x)
-        print(f"Exiting ZeroPaddedConv2D: {out.shape}")
         return out, cache

@@ -378,81 +369,65 @@ def initialize_cache(self, batch_size, height, width, dtype):
     def __call__(
         self, x: jax.Array, cache: Dict[str, Any] = None
     ) -> Tuple[jax.Array, Dict[str, Any]]:
-        print(f"Entering WanResample ({self.mode}): {x.shape}")
         if cache is None:
             cache = {}
         new_cache = {}

         if self.mode == "upsample2d":
             b, t, h, w, c = x.shape
             x = x.reshape(b * t, h, w, c)
-            print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
             x = self.resample(x)
-            print(f"WanResample ({self.mode}) after resample: {x.shape}")
             h_new, w_new, c_new = x.shape[1:]
             x = x.reshape(b, t, h_new, w_new, c_new)

         elif self.mode == "upsample3d":
             x, tc_cache = self.time_conv(x, cache.get("time_conv"))
             new_cache["time_conv"] = tc_cache
-            print(f"WanResample ({self.mode}) after time_conv: {x.shape}")

             b, t, h, w, c = x.shape
             x = x.reshape(b, t, h, w, 2, c // 2)
             x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
             x = x.reshape(b, t * 2, h, w, c // 2)
-            print(f"WanResample ({self.mode}) after time dim expand: {x.shape}")
-

             b, t, h, w, c = x.shape
             x = x.reshape(b * t, h, w, c)
-            print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
             x = self.resample(x)
-            print(f"WanResample ({self.mode}) after resample: {x.shape}")
             h_new, w_new, c_new = x.shape[1:]
             x = x.reshape(b, t, h_new, w_new, c_new)

         elif self.mode == "downsample2d":
             b, t, h, w, c = x.shape
             x = x.reshape(b * t, h, w, c)
-            print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
             x, _ = self.resample(x, None)
-            print(f"WanResample ({self.mode}) after resample: {x.shape}")
             h_new, w_new, c_new = x.shape[1:]
             x = x.reshape(b, t, h_new, w_new, c_new)

         elif self.mode == "downsample3d":
             if x.shape[1] >= self.time_conv.kernel_size[0]:
                 x, tc_cache = self.time_conv(x, cache.get("time_conv"))
                 new_cache["time_conv"] = tc_cache
-                print(f"WanResample ({self.mode}) after time_conv: {x.shape}")
             else:
                 # Skip temporal downsampling if not enough frames
-                print(f"WanResample ({self.mode}): Skipping time_conv, input time dim {x.shape[1]} < kernel {self.time_conv.kernel_size[0]}")
                 new_cache["time_conv"] = cache.get("time_conv")  # Pass through cache

             b, t, h, w, c = x.shape
             if b * t > 0:
                 x = x.reshape(b * t, h, w, c)
-                print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
                 x, _ = self.resample(x, None)
-                print(f"WanResample ({self.mode}) after resample: {x.shape}")
                 h_new, w_new, c_new = x.shape[1:]
                 x = x.reshape(b, t, h_new, w_new, c_new)
             else:
                 # If time dimension became 0, spatial shape changes, but batch and time are still 0
                 h_new, w_new = h // self.resample.conv.strides[0], w // self.resample.conv.strides[1]
                 c_new = self.resample.conv.out_features
                 x = jnp.zeros((b, t, h_new, w_new, c_new), dtype=x.dtype)
-                print(f"WanResample ({self.mode}): Spatial downsample output shape {x.shape} (due to t=0)")
         else:
             if hasattr(self, "resample"):
                 if isinstance(self.resample, Identity):
                     x, _ = self.resample(x, None)
                 else:
                     x = self.resample(x)

-        print(f"Exiting WanResample ({self.mode}): {x.shape}")
         return x, new_cache


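A toy-shape sketch of the channel-to-time split in the `upsample3d` branch above, after `time_conv` has doubled the channel count:

```python
import jax.numpy as jnp

b, t, h, w, c = 1, 2, 1, 1, 4            # toy shapes; c is the doubled channel count
x = jnp.arange(b * t * h * w * c).reshape(b, t, h, w, c)
x = x.reshape(b, t, h, w, 2, c // 2)
x = jnp.stack([x[..., 0, :], x[..., 1, :]], axis=1)  # (b, 2, t, h, w, c//2)
x = x.reshape(b, t * 2, h, w, c // 2)
# Note: the C-order reshape places all first-half frames before all
# second-half frames rather than interleaving them; the [0, 2, 1, 3] frame
# permutation in _decode_jit's scan (later in this diff) appears to
# compensate for that ordering.
```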
@@ -526,33 +501,26 @@ def initialize_cache(self, batch_size, height, width, dtype):
         return cache

     def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
-        print(f"Entering WanResidualBlock (in={self.conv1.conv.in_features}, out={self.conv1.conv.out_features}): {x.shape}")
         if cache is None:
             cache = {}
         new_cache = {}

         h, sc_cache = self.conv_shortcut(x, cache.get("shortcut"))
         new_cache["shortcut"] = sc_cache
-        print(f"WanResidualBlock after shortcut: {h.shape}")

         x = self.norm1(x)
         x = self.nonlinearity(x)
-        print(f"WanResidualBlock after norm1/nl: {x.shape}")

         x, c1 = self.conv1(x, cache.get("conv1"))
         new_cache["conv1"] = c1
-        print(f"WanResidualBlock after conv1: {x.shape}")

         x = self.norm2(x)
         x = self.nonlinearity(x)
-        print(f"WanResidualBlock after norm2/nl: {x.shape}")

         x, c2 = self.conv2(x, cache.get("conv2"))
         new_cache["conv2"] = c2
-        print(f"WanResidualBlock after conv2: {x.shape}")

         x = x + h
-        print(f"Exiting WanResidualBlock: {x.shape}")
         return x, new_cache


@@ -591,17 +559,14 @@ def __init__(
         )

     def __call__(self, x: jax.Array):
-        print(f"Entering WanAttentionBlock: {x.shape}")
         identity = x
         batch_size, time, height, width, channels = x.shape

         x = x.reshape(batch_size * time, height, width, channels)
-        print(f"WanAttentionBlock reshaped for norm: {x.shape}")
         x = self.norm(x)

         qkv = self.to_qkv(x)  # Output: (N*D, H, W, C * 3)
         # qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
-        print(f"WanAttentionBlock qkv shape: {qkv.shape}")
         qkv = qkv.reshape(batch_size * time, 1, -1, channels * 3)
         qkv = jnp.transpose(qkv, (0, 1, 3, 2))
         q, k, v = jnp.split(qkv, 3, axis=-2)
@@ -616,7 +581,6 @@ def __call__(self, x: jax.Array):
         # Reshape back
         x = x.reshape(batch_size, time, height, width, channels)
         out = x + identity
-        print(f"Exiting WanAttentionBlock: {out.shape}")
         return out

@@ -685,26 +649,19 @@ def initialize_cache(self, batch_size, height, width, dtype):
         return cache

     def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
-        print(f"Entering WanMidBlock: {x.shape}")
         if cache is None:
             cache = {}
         new_cache = {"resnets": []}

         x, c = self.resnets[0](x, cache.get("resnets", [None])[0])
         new_cache["resnets"].append(c)
-        print(f"WanMidBlock after resnets[0]: {x.shape}")

         for i, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
             if attn is not None:
-                print(f"WanMidBlock before attn {i}: {x.shape}")
                 x = attn(x)
-                print(f"WanMidBlock after attn {i}: {x.shape}")
-            print(f"WanMidBlock before resnets[{i + 1}]: {x.shape}")
             x, c = resnet(x, cache.get("resnets", [None] * len(self.resnets))[i + 1])
             new_cache["resnets"].append(c)
-            print(f"WanMidBlock after resnets[{i + 1}]: {x.shape}")

-        print(f"Exiting WanMidBlock: {x.shape}")
         return x, new_cache


@@ -922,41 +879,32 @@ def init_cache(self, batch_size, height, width, dtype):
         return cache

     def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
-        print(f"Entering WanEncoder3d: {x.shape}")
         if cache is None:
             cache = {}
         new_cache = {}

         x, c = self.conv_in(x, cache.get("conv_in"))
         new_cache["conv_in"] = c
-        print(f"WanEncoder3d after conv_in: {x.shape}")

         new_cache["down_blocks"] = []
         current_down_caches = cache.get("down_blocks", [None] * len(self.down_blocks))

         for i, layer in enumerate(self.down_blocks):
-            print(f"WanEncoder3d before down_block {i} ({type(layer).__name__}): {x.shape}")
             if isinstance(layer, (WanResidualBlock, WanResample)):
                 x, c = layer(x, current_down_caches[i])
                 new_cache["down_blocks"].append(c)
             else:
                 x = layer(x)
                 new_cache["down_blocks"].append(None)
-            print(f"WanEncoder3d after down_block {i}: {x.shape}")
-

         x, c = self.mid_block(x, cache.get("mid_block"))
         new_cache["mid_block"] = c
-        print(f"WanEncoder3d after mid_block: {x.shape}")

         x = self.norm_out(x)
-        print(f"WanEncoder3d after norm_out: {x.shape}")
         x = self.nonlinearity(x)
-        print(f"WanEncoder3d after nonlinearity: {x.shape}")

         x, c = self.conv_out(x, cache.get("conv_out"))
         new_cache["conv_out"] = c
-        print(f"Exiting WanEncoder3d: {x.shape}")

         return x, new_cache

@@ -1270,8 +1218,7 @@ def _encode_jit(self, x: jax.Array) -> jax.Array:
         # Process the first frame (Time=1)
         x_first = x[:, :1, ...]
         init_cache_first = self.encoder.init_cache(b, h, w, x_first.dtype)
-        encoder_checkpointed = jax.checkpoint(self.encoder)
-        out1, state_carry = encoder_checkpointed(x_first, init_cache_first)
+        out1, state_carry = self.encoder(x_first, init_cache_first)
         all_outs.append(out1)

         # Process the remaining frames using scan over chunks of 4
@@ -1295,7 +1242,7 @@ def _encode_jit(self, x: jax.Array) -> jax.Array:
         x_scannable = jnp.swapaxes(x_reshaped, 0, 1)

         def scan_fn(carry_state, x_chunk):
-            out_chunk, new_state = encoder_checkpointed(x_chunk, carry_state)
+            out_chunk, new_state = self.encoder(x_chunk, carry_state)
             return new_state, out_chunk

         _, encoded_chunks = jax.lax.scan(scan_fn, state_carry, x_scannable)
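The carry pattern here in miniature: `jax.lax.scan` threads a cache-like carry through successive chunks and stacks each step's output along the leading axis. Toy stand-ins below, not the real encoder:

```python
import jax
import jax.numpy as jnp

def scan_fn(carry, x_chunk):
    new_carry = carry + 1.0           # stand-in for the updated encoder cache
    out_chunk = x_chunk * new_carry   # stand-in for the encoded chunk
    return new_carry, out_chunk

final_carry, outs = jax.lax.scan(scan_fn, jnp.float32(0.0), jnp.ones((4, 2)))
# outs has shape (4, 2): one stacked output per scanned chunk
```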
@@ -1331,19 +1278,18 @@ def _decode_jit(self, z: jax.Array) -> jax.Array:

         b, t, h, w, c = x.shape
         init_cache = self.decoder.init_cache(b, h, w, x.dtype)
-        decoder_checkpointed = jax.checkpoint(self.decoder)

         all_decoded = []
         x_first = x[:, :1, ...]
-        out_first, state_carry = decoder_checkpointed(x_first, init_cache)
+        out_first, state_carry = self.decoder(x_first, init_cache)
         all_decoded.append(out_first)
         if t > 1:
             x_rest = x[:, 1:, ...]
             x_scan = jnp.swapaxes(x_rest, 0, 1)

             def scan_fn(carry, input_slice):
                 input_slice = jnp.expand_dims(input_slice, 1)
-                out_slice, new_carry = decoder_checkpointed(input_slice, carry)
+                out_slice, new_carry = self.decoder(input_slice, carry)
                 out_swapped = out_slice[:, jnp.array([0, 2, 1, 3]), ...]

                 return new_carry, out_swapped