@@ -421,18 +421,29 @@ def __call__(
421421 x = x .reshape (b , t , h_new , w_new , c_new )
422422
423423 elif self .mode == "downsample3d" :
424- x , tc_cache = self .time_conv (x , cache .get ("time_conv" ))
425- new_cache ["time_conv" ] = tc_cache
426- print (f"WanResample ({ self .mode } ) after time_conv: { x .shape } " )
424+ if x .shape [1 ] >= self .time_conv .kernel_size [0 ]:
425+ x , tc_cache = self .time_conv (x , cache .get ("time_conv" ))
426+ new_cache ["time_conv" ] = tc_cache
427+ print (f"WanResample ({ self .mode } ) after time_conv: { x .shape } " )
428+ else :
429+ # Skip temporal downsampling if not enough frames
430+ print (f"WanResample ({ self .mode } ): Skipping time_conv, input time dim { x .shape [1 ]} < kernel { self .time_conv .kernel_size [0 ]} " )
431+ new_cache ["time_conv" ] = cache .get ("time_conv" ) # Pass through cache
427432
428433 b , t , h , w , c = x .shape
429- x = x .reshape (b * t , h , w , c )
430- print (f"WanResample ({ self .mode } ) reshaped for resample: { x .shape } " )
431- x , _ = self .resample (x , None )
432- print (f"WanResample ({ self .mode } ) after resample: { x .shape } " )
433- h_new , w_new , c_new = x .shape [1 :]
434- x = x .reshape (b , t , h_new , w_new , c_new )
435-
434+ if b * t > 0 :
435+ x = x .reshape (b * t , h , w , c )
436+ print (f"WanResample ({ self .mode } ) reshaped for resample: { x .shape } " )
437+ x , _ = self .resample (x , None )
438+ print (f"WanResample ({ self .mode } ) after resample: { x .shape } " )
439+ h_new , w_new , c_new = x .shape [1 :]
440+ x = x .reshape (b , t , h_new , w_new , c_new )
441+ else :
442+ # If time dimension became 0, spatial shape changes, but batch and time are still 0
443+ h_new , w_new = h // self .resample .conv .strides [0 ], w // self .resample .conv .strides [1 ]
444+ c_new = self .resample .conv .out_features
445+ x = jnp .zeros ((b , t , h_new , w_new , c_new ), dtype = x .dtype )
446+ print (f"WanResample ({ self .mode } ): Spatial downsample output shape { x .shape } (due to t=0)" )
436447 else :
437448 if hasattr (self , "resample" ):
438449 if isinstance (self .resample , Identity ):
0 commit comments