Skip to content

Commit 08dc2c1

Browse files
committed
Fix
1 parent cd093e5 commit 08dc2c1

1 file changed

Lines changed: 14 additions & 12 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -168,18 +168,14 @@ def __init__(self, dim: int, mode: str, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
168168
self.dim = dim
169169
self.mode = mode
170170

171-
# FIX: Removed pre-initialization of attributes to None to avoid NNX errors.
172-
173171
if mode == "upsample2d":
174172
self.resample = nnx.Sequential(
175173
WanUpsample(scale_factor=(2.0, 2.0), method="nearest"),
176174
nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision)
177175
)
178176
elif mode == "upsample3d":
179-
self.resample = nnx.Sequential(
180-
WanUpsample(scale_factor=(2.0, 2.0), method="nearest"),
181-
nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision)
182-
)
177+
self.upsample = WanUpsample(scale_factor=(2.0, 2.0), method="nearest")
178+
self.conv = nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision)
183179
self.time_conv = WanCausalConv3d(rngs=rngs, in_channels=dim, out_channels=dim * 2, kernel_size=(3, 1, 1), padding=(1, 0, 0), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
184180
elif mode == "downsample2d":
185181
self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
@@ -220,7 +216,8 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
220216

221217
b, t, h, w, c = x.shape
222218
x = x.reshape(b * t, h, w, c)
223-
x = self.resample(x) # Sequential
219+
x = self.upsample(x)
220+
x = self.conv(x)
224221
h_new, w_new, c_new = x.shape[1:]
225222
x = x.reshape(b, t, h_new, w_new, c_new)
226223

@@ -234,18 +231,19 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
234231
elif self.mode == "downsample3d":
235232
b, t, h, w, c = x.shape
236233
x = x.reshape(b * t, h, w, c)
237-
x, _ = self.resample(x, None) # ZeroPaddedConv2D
234+
x, _ = self.resample(x, None) # Fixed: use self.resample not self.downsample_conv
238235
h_new, w_new, c_new = x.shape[1:]
239236
x = x.reshape(b, t, h_new, w_new, c_new)
240237

241238
x, tc_cache = self.time_conv(x, cache.get("time_conv"))
242239
new_cache["time_conv"] = tc_cache
243240

244241
else:
245-
if isinstance(self.resample, Identity):
246-
x, _ = self.resample(x, None)
247-
else:
248-
x = self.resample(x)
242+
if hasattr(self, "resample"):
243+
if isinstance(self.resample, Identity):
244+
x, _ = self.resample(x, None)
245+
else:
246+
x = self.resample(x)
249247

250248
return x, new_cache
251249

@@ -532,6 +530,10 @@ def __init__(self, rngs: nnx.Rngs, base_dim: int = 96, z_dim: int = 16, dim_mult
532530
self.temperal_downsample = temperal_downsample
533531
self.temporal_upsample = temperal_downsample[::-1]
534532

533+
# Restore the latents_mean and latents_std attributes, which were missing from this constructor.
534+
self.latents_mean = latents_mean
535+
self.latents_std = latents_std
536+
535537
self.encoder = WanEncoder3d(rngs=rngs, dim=base_dim, z_dim=z_dim * 2, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, dropout=dropout, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
536538
self.quant_conv = WanCausalConv3d(rngs=rngs, in_channels=z_dim * 2, out_channels=z_dim * 2, kernel_size=1, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
537539
self.post_quant_conv = WanCausalConv3d(rngs=rngs, in_channels=z_dim, out_channels=z_dim, kernel_size=1, mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)

0 commit comments

Comments
 (0)