@@ -177,26 +177,35 @@ class WanResample(nnx.Module):
     def __init__(self, dim: int, mode: str, rngs: nnx.Rngs, mesh: jax.sharding.Mesh = None, dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None):
         self.dim = dim
         self.mode = mode
-
+
         if mode == "upsample2d":
-            self.upsample = WanUpsample(scale_factor=(2.0, 2.0), method="nearest")
-            self.conv = nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision)
+            # FIX: wrap in nnx.Sequential so parameter paths match the checkpoint keys
+            # (see the key-path sketch after this hunk)
+            self.resample = nnx.Sequential(
+                WanUpsample(scale_factor=(2.0, 2.0), method="nearest"),
+                nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision),
+            )
         elif mode == "upsample3d":
+            # 3D mode uses explicit attributes for cache handling
             self.upsample = WanUpsample(scale_factor=(2.0, 2.0), method="nearest")
             self.conv = nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision)
             self.time_conv = WanCausalConv3d(rngs=rngs, in_channels=dim, out_channels=dim * 2, kernel_size=(3, 1, 1), padding=(1, 0, 0), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
         elif mode == "downsample2d":
-            self.downsample_conv = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
+            # FIX: rename to self.resample; the error log showed the downsample keys
+            # were also missing under the old name. ZeroPaddedConv2D is a Module itself,
+            # so direct assignment works without a Sequential wrapper.
+            self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
         elif mode == "downsample3d":
+            # 3D mode likewise keeps explicit attributes for cache handling
             self.downsample_conv = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
             self.time_conv = WanCausalConv3d(rngs=rngs, in_channels=dim, out_channels=dim, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
+        else:
+            self.resample = Identity()
 
     def initialize_cache(self, batch_size, height, width, dtype):
         cache = {}
         if hasattr(self, "time_conv"):
             h_curr, w_curr = height, width
             if self.mode == "downsample3d":
-                # Resample (stride 2) happens before time conv
                 h_curr, w_curr = height // 2, width // 2
             cache["time_conv"] = self.time_conv.initialize_cache(batch_size, h_curr, w_curr, dtype)
         return cache
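
Two quick checks help validate this hunk. First, the attribute renames only matter insofar as the flattened parameter paths line up with the checkpoint, and those paths can be inspected directly. Second, `initialize_cache` allocates the downsample3d time-conv cache at the halved resolution because the stride-2 spatial conv runs first. A minimal sketch, assuming `WanResample` is importable as defined above (the `dim=64` value and the 32x32 input are illustrative assumptions):

```python
import jax
import jax.numpy as jnp
from flax import nnx

# List flattened parameter paths to compare against the checkpoint's key names.
block = WanResample(dim=64, mode="upsample2d", rngs=nnx.Rngs(0))
leaves, _ = jax.tree_util.tree_flatten_with_path(nnx.state(block))
for path, leaf in leaves:
    print(jax.tree_util.keystr(path), getattr(leaf, "shape", None))
# The conv should now appear under resample/layers/... rather than under .conv.

# downsample3d: the stride-2 spatial conv runs before time_conv, so the cache
# for a 32x32 input is allocated for 16x16 feature maps.
block3d = WanResample(dim=64, mode="downsample3d", rngs=nnx.Rngs(0))
cache = block3d.initialize_cache(batch_size=1, height=32, width=32, dtype=jnp.float32)
```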
@@ -206,25 +215,22 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
         new_cache = {}
 
         if self.mode == "upsample2d":
+            # Use self.resample (the Sequential defined in __init__)
             b, t, h, w, c = x.shape
             x = x.reshape(b * t, h, w, c)
-            x = self.upsample(x)
-            x = self.conv(x)
+            x = self.resample(x)
             h_new, w_new, c_new = x.shape[1:]
             x = x.reshape(b, t, h_new, w_new, c_new)
 
         elif self.mode == "upsample3d":
-            # TimeConv -> Reshape -> Resample
             x, tc_cache = self.time_conv(x, cache.get("time_conv"))
             new_cache["time_conv"] = tc_cache
 
             b, t, h, w, c = x.shape
-            # Split channels for time upsample
             x = x.reshape(b, t, h, w, 2, c // 2)
             x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
             x = x.reshape(b, t * 2, h, w, c // 2)
 
-            # Spatial resample
             b, t, h, w, c = x.shape
             x = x.reshape(b * t, h, w, c)
             x = self.upsample(x)
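
The reshape/stack in the upsample3d branch converts the channels doubled by `time_conv` into a doubled time axis. A toy-tensor sketch of the shape bookkeeping (all sizes are illustrative assumptions):

```python
import jax.numpy as jnp

b, t, h, w, c = 1, 3, 2, 2, 8                 # c is the doubled channel count from time_conv
x = jnp.zeros((b, t, h, w, c))

x = x.reshape(b, t, h, w, 2, c // 2)          # split channels into two temporal slots
x = jnp.stack([x[:, :, :, :, 0, :],
               x[:, :, :, :, 1, :]], axis=1)  # -> (b, 2, t, h, w, c // 2)
x = x.reshape(b, t * 2, h, w, c // 2)         # -> (1, 6, 2, 2, 4): time axis doubled
```

Note that stacking on axis=1 emits all slot-0 frames before all slot-1 frames, whereas stacking on axis=2 would interleave them per source frame; which ordering the checkpoint expects is worth verifying against the reference implementation.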
@@ -233,9 +239,10 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
             x = x.reshape(b, t, h_new, w_new, c_new)
 
         elif self.mode == "downsample2d":
+            # Use self.resample (the ZeroPaddedConv2D defined in __init__)
             b, t, h, w, c = x.shape
             x = x.reshape(b * t, h, w, c)
-            x, _ = self.downsample_conv(x, None)
+            x, _ = self.resample(x, None)
             h_new, w_new, c_new = x.shape[1:]
             x = x.reshape(b, t, h_new, w_new, c_new)
 
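
The 2D path is stateless, so a plain shape check exercises the renamed attribute end to end. A minimal sketch, assuming ZeroPaddedConv2D preserves the channel count while its (2, 2) stride halves the spatial dims (the input sizes are illustrative):

```python
import jax.numpy as jnp
from flax import nnx

block = WanResample(dim=64, mode="downsample2d", rngs=nnx.Rngs(0))
x = jnp.zeros((1, 4, 32, 32, 64))  # (batch, time, height, width, channels)
y, _ = block(x, None)              # 2D modes ignore the cache argument
print(y.shape)                     # expected (1, 4, 16, 16, 64) under the assumptions above
```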
@@ -245,14 +252,15 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
             x, _ = self.downsample_conv(x, None)
             h_new, w_new, c_new = x.shape[1:]
             x = x.reshape(b, t, h_new, w_new, c_new)
-
+
             x, tc_cache = self.time_conv(x, cache.get("time_conv"))
             new_cache["time_conv"] = tc_cache
+
         else:
             if hasattr(self, "resample"):
-                x, _ = self.resample(x, None)
-                return x, new_cache
+                x, _ = self.resample(x, None)
 
+        return x, new_cache
 
 class WanResidualBlock(nnx.Module):
     def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity: str = "silu", mesh: jax.sharding.Mesh = None, dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None):
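
With the `return` dedented to method level, every mode now returns the `(x, new_cache)` pair, so callers can thread the cache uniformly across temporal chunks. A minimal usage sketch (the chunk count and input shape are illustrative assumptions):

```python
import jax.numpy as jnp
from flax import nnx

block = WanResample(dim=64, mode="downsample3d", rngs=nnx.Rngs(0))
cache = block.initialize_cache(batch_size=1, height=32, width=32, dtype=jnp.float32)

# Feed the video in temporal chunks, carrying the time-conv cache forward.
video = jnp.zeros((1, 8, 32, 32, 64))
outputs = []
for chunk in jnp.split(video, 2, axis=1):
    out, cache = block(chunk, cache)
    outputs.append(out)
```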