@@ -178,26 +178,102 @@ def __init__(self, dim: int, mode: str, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
178178 self .dim = dim
179179 self .mode = mode
180180
181+ # ATTRIBUTES MUST BE DEFINED IN THE INIT PATH
182+ # We use different attribute names depending on mode to match strict checkpoint keys if needed,
183+ # OR we rely on the fact that the checkpoint loading mapping handles the name translation.
184+ # Based on the error, the checkpoint expects 'resample' to be a Sequential for 2D modes.
185+
181186 if mode == "upsample2d" :
182- # FIX: Use Sequential to match checkpoint keys
187+ # Map: resample.layers.0 -> WanUpsample
188+ # Map: resample.layers.1 -> Conv
183189 self .resample = nnx .Sequential (
184190 WanUpsample (scale_factor = (2.0 , 2.0 ), method = "nearest" ),
185- nnx .Conv (dim , dim // 2 , kernel_size = (3 , 3 ), padding = "SAME" , use_bias = True , rngs = rngs , kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , None , None , "conv_out" )), dtype = dtype , param_dtype = weights_dtype , precision = precision )
191+ nnx .Conv (
192+ dim ,
193+ dim // 2 ,
194+ kernel_size = (3 , 3 ),
195+ padding = "SAME" ,
196+ use_bias = True ,
197+ rngs = rngs ,
198+ kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , None , None , "conv_out" )),
199+ dtype = dtype ,
200+ param_dtype = weights_dtype ,
201+ precision = precision
202+ )
186203 )
204+
187205 elif mode == "upsample3d" :
188- # 3D mode uses explicit attributes for cache handling
206+ # 3D mode: __call__ uses 'upsample' and 'conv' as separate components,
207+ # so we keep explicit attributes here, as in previously working versions.
208+ # If a checkpoint instead stores keys under 'resample.layers...', the
209+ # checkpoint-loading key mapping must translate those names to the
210+ # explicit attributes defined below.
189211 self .upsample = WanUpsample (scale_factor = (2.0 , 2.0 ), method = "nearest" )
190- self .conv = nnx .Conv (dim , dim // 2 , kernel_size = (3 , 3 ), padding = "SAME" , use_bias = True , rngs = rngs , kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , None , None , "conv_out" )), dtype = dtype , param_dtype = weights_dtype , precision = precision )
191- self .time_conv = WanCausalConv3d (rngs = rngs , in_channels = dim , out_channels = dim * 2 , kernel_size = (3 , 1 , 1 ), padding = (1 , 0 , 0 ), mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
212+ self .conv = nnx .Conv (
213+ dim ,
214+ dim // 2 ,
215+ kernel_size = (3 , 3 ),
216+ padding = "SAME" ,
217+ use_bias = True ,
218+ rngs = rngs ,
219+ kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , None , None , "conv_out" )),
220+ dtype = dtype ,
221+ param_dtype = weights_dtype ,
222+ precision = precision
223+ )
224+ self .time_conv = WanCausalConv3d (
225+ rngs = rngs ,
226+ in_channels = dim ,
227+ out_channels = dim * 2 ,
228+ kernel_size = (3 , 1 , 1 ),
229+ padding = (1 , 0 , 0 ),
230+ mesh = mesh ,
231+ dtype = dtype ,
232+ weights_dtype = weights_dtype ,
233+ precision = precision
234+ )
235+
192236 elif mode == "downsample2d" :
193- # FIX: Use Sequential/Wrapper to match checkpoint keys if needed,
194- # but ZeroPaddedConv2D is a Module itself, so direct assignment is likely fine unless checkpoint wrapped it.
195- # Based on error log, downsample keys were missing too.
196- self .resample = ZeroPaddedConv2D (dim = dim , rngs = rngs , kernel_size = (3 , 3 ), stride = (2 , 2 ), mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
237+ # Downsample 2D is a single strided conv (ZeroPaddedConv2D).
238+ # The error log referenced key paths under 'resample', so the
239+ # attribute is named 'resample' to match the checkpoint's key
240+ # structure. If a checkpoint wraps this module in a Sequential
241+ # ('resample.layers...'), the checkpoint-loading mapping must
242+ # translate those keys to this attribute.
243+ self .resample = ZeroPaddedConv2D (
244+ dim = dim ,
245+ rngs = rngs ,
246+ kernel_size = (3 , 3 ),
247+ stride = (2 , 2 ),
248+ mesh = mesh ,
249+ dtype = dtype ,
250+ weights_dtype = weights_dtype ,
251+ precision = precision
252+ )
253+
197254 elif mode == "downsample3d" :
198- # 3D mode explicit
199- self .downsample_conv = ZeroPaddedConv2D (dim = dim , rngs = rngs , kernel_size = (3 , 3 ), stride = (2 , 2 ), mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
200- self .time_conv = WanCausalConv3d (rngs = rngs , in_channels = dim , out_channels = dim , kernel_size = (3 , 1 , 1 ), stride = (2 , 1 , 1 ), padding = (0 , 0 , 0 ), mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
255+ self .downsample_conv = ZeroPaddedConv2D (
256+ dim = dim ,
257+ rngs = rngs ,
258+ kernel_size = (3 , 3 ),
259+ stride = (2 , 2 ),
260+ mesh = mesh ,
261+ dtype = dtype ,
262+ weights_dtype = weights_dtype ,
263+ precision = precision
264+ )
265+ self .time_conv = WanCausalConv3d (
266+ rngs = rngs ,
267+ in_channels = dim ,
268+ out_channels = dim ,
269+ kernel_size = (3 , 1 , 1 ),
270+ stride = (2 , 1 , 1 ),
271+ padding = (0 , 0 , 0 ),
272+ mesh = mesh ,
273+ dtype = dtype ,
274+ weights_dtype = weights_dtype ,
275+ precision = precision
276+ )
201277 else :
202278 self .resample = Identity ()
203279
@@ -215,9 +291,9 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
215291 new_cache = {}
216292
217293 if self .mode == "upsample2d" :
218- # Use self.resample (Sequential)
219294 b , t , h , w , c = x .shape
220295 x = x .reshape (b * t , h , w , c )
296+ # Using Sequential
221297 x = self .resample (x )
222298 h_new , w_new , c_new = x .shape [1 :]
223299 x = x .reshape (b , t , h_new , w_new , c_new )
@@ -239,9 +315,11 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
239315 x = x .reshape (b , t , h_new , w_new , c_new )
240316
241317 elif self .mode == "downsample2d" :
242- # Use self.resample (ZeroPaddedConv2D)
243318 b , t , h , w , c = x .shape
244319 x = x .reshape (b * t , h , w , c )
320+ # Here self.resample is ZeroPaddedConv2D, whose __call__ takes
321+ # (x, cache) and returns (out, cache); the returned cache is
322+ # unused for the 2D path, so it is discarded.
245323 x , _ = self .resample (x , None )
246324 h_new , w_new , c_new = x .shape [1 :]
247325 x = x .reshape (b , t , h_new , w_new , c_new )
@@ -258,7 +336,12 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
258336
259337 else :
260338 if hasattr (self , "resample" ):
261- x , _ = self .resample (x , None )
339+ # Identity check
340+ if isinstance (self .resample , Identity ):
341+ x , _ = self .resample (x , None )
342+ else :
343+ # Fallback for a module with a single-argument call signature
344+ x = self .resample (x )
262345
263346 return x , new_cache
264347
0 commit comments