Skip to content

Commit dc4a830

Browse files
committed
Fix
1 parent 08dc2c1 commit dc4a830

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,11 @@ def __init__(self, dim: int, mode: str, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
174174
nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision)
175175
)
176176
elif mode == "upsample3d":
177-
self.upsample = WanUpsample(scale_factor=(2.0, 2.0), method="nearest")
178-
self.conv = nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision)
177+
# 3D mode ALSO needs Sequential for the spatial part to match checkpoints
178+
self.resample = nnx.Sequential(
179+
WanUpsample(scale_factor=(2.0, 2.0), method="nearest"),
180+
nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision)
181+
)
179182
self.time_conv = WanCausalConv3d(rngs=rngs, in_channels=dim, out_channels=dim * 2, kernel_size=(3, 1, 1), padding=(1, 0, 0), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
180183
elif mode == "downsample2d":
181184
self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)

0 commit comments

Comments
 (0)