@@ -174,8 +174,11 @@ def __init__(self, dim: int, mode: str, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
174174 nnx .Conv (dim , dim // 2 , kernel_size = (3 , 3 ), padding = "SAME" , use_bias = True , rngs = rngs , kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , None , None , "conv_out" )), dtype = dtype , param_dtype = weights_dtype , precision = precision )
175175 )
176176 elif mode == "upsample3d" :
177- self .upsample = WanUpsample (scale_factor = (2.0 , 2.0 ), method = "nearest" )
178- self .conv = nnx .Conv (dim , dim // 2 , kernel_size = (3 , 3 ), padding = "SAME" , use_bias = True , rngs = rngs , kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , None , None , "conv_out" )), dtype = dtype , param_dtype = weights_dtype , precision = precision )
177+ # "upsample3d" ALSO wraps its spatial upsample + conv in nnx.Sequential (as self.resample) to match checkpoints
178+ self .resample = nnx .Sequential (
179+ WanUpsample (scale_factor = (2.0 , 2.0 ), method = "nearest" ),
180+ nnx .Conv (dim , dim // 2 , kernel_size = (3 , 3 ), padding = "SAME" , use_bias = True , rngs = rngs , kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , None , None , "conv_out" )), dtype = dtype , param_dtype = weights_dtype , precision = precision )
181+ )
179182 self .time_conv = WanCausalConv3d (rngs = rngs , in_channels = dim , out_channels = dim * 2 , kernel_size = (3 , 1 , 1 ), padding = (1 , 0 , 0 ), mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
180183 elif mode == "downsample2d" :
181184 self .resample = ZeroPaddedConv2D (dim = dim , rngs = rngs , kernel_size = (3 , 3 ), stride = (2 , 2 ), mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
0 commit comments