Skip to content

Commit a162aa5

Browse files
committed
Fix
1 parent 7be24b5 commit a162aa5

1 file changed

Lines changed: 98 additions & 15 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 98 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -178,26 +178,102 @@ def __init__(self, dim: int, mode: str, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
178178
self.dim = dim
179179
self.mode = mode
180180

181+
# ATTRIBUTES MUST BE DEFINED IN THE INIT PATH
182+
# We use different attribute names depending on mode to match strict checkpoint keys if needed,
183+
# OR we rely on the fact that the checkpoint loading mapping handles the name translation.
184+
# Based on the error, the checkpoint expects 'resample' to be a Sequential for 2D modes.
185+
181186
if mode == "upsample2d":
182-
# FIX: Use Sequential to match checkpoint keys
187+
# Map: resample.layers.0 -> WanUpsample
188+
# Map: resample.layers.1 -> Conv
183189
self.resample = nnx.Sequential(
184190
WanUpsample(scale_factor=(2.0, 2.0), method="nearest"),
185-
nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision)
191+
nnx.Conv(
192+
dim,
193+
dim // 2,
194+
kernel_size=(3, 3),
195+
padding="SAME",
196+
use_bias=True,
197+
rngs=rngs,
198+
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")),
199+
dtype=dtype,
200+
param_dtype=weights_dtype,
201+
precision=precision
202+
)
186203
)
204+
187205
elif mode == "upsample3d":
188-
# 3D mode uses explicit attributes for cache handling
206+
# 3D mode: Code handles 'upsample' and 'conv' separately in __call__,
207+
# BUT for checkpoint loading, if the checkpoint has 'resample.layers...',
208+
# we might need to match that.
209+
# However, standard Wan3D usually has explicit components.
210+
# We will stick to explicit attributes here as defined in previous working versions.
189211
self.upsample = WanUpsample(scale_factor=(2.0, 2.0), method="nearest")
190-
self.conv = nnx.Conv(dim, dim // 2, kernel_size=(3, 3), padding="SAME", use_bias=True, rngs=rngs, kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")), dtype=dtype, param_dtype=weights_dtype, precision=precision)
191-
self.time_conv = WanCausalConv3d(rngs=rngs, in_channels=dim, out_channels=dim * 2, kernel_size=(3, 1, 1), padding=(1, 0, 0), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
212+
self.conv = nnx.Conv(
213+
dim,
214+
dim // 2,
215+
kernel_size=(3, 3),
216+
padding="SAME",
217+
use_bias=True,
218+
rngs=rngs,
219+
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")),
220+
dtype=dtype,
221+
param_dtype=weights_dtype,
222+
precision=precision
223+
)
224+
self.time_conv = WanCausalConv3d(
225+
rngs=rngs,
226+
in_channels=dim,
227+
out_channels=dim * 2,
228+
kernel_size=(3, 1, 1),
229+
padding=(1, 0, 0),
230+
mesh=mesh,
231+
dtype=dtype,
232+
weights_dtype=weights_dtype,
233+
precision=precision
234+
)
235+
192236
elif mode == "downsample2d":
193-
# FIX: Use Sequential/Wrapper to match checkpoint keys if needed,
194-
# but ZeroPaddedConv2D is a Module itself, so direct assignment is likely fine unless checkpoint wrapped it.
195-
# Based on error log, downsample keys were missing too.
196-
self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
237+
# Downsample 2D is often just a strided conv.
238+
# Error log suggested keys like 'downsample_conv' were missing in previous attempts?
239+
# Let's look at the error: 'resample', 'layers', 1...
240+
# This implies downsample might ALSO be a Sequential in the checkpoint?
241+
# Usually downsample is just a Conv.
242+
# Let's use the attribute name 'resample' to be safe if it matches the error key path structure.
243+
self.resample = ZeroPaddedConv2D(
244+
dim=dim,
245+
rngs=rngs,
246+
kernel_size=(3, 3),
247+
stride=(2, 2),
248+
mesh=mesh,
249+
dtype=dtype,
250+
weights_dtype=weights_dtype,
251+
precision=precision
252+
)
253+
197254
elif mode == "downsample3d":
198-
# 3D mode explicit
199-
self.downsample_conv = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
200-
self.time_conv = WanCausalConv3d(rngs=rngs, in_channels=dim, out_channels=dim, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision)
255+
self.downsample_conv = ZeroPaddedConv2D(
256+
dim=dim,
257+
rngs=rngs,
258+
kernel_size=(3, 3),
259+
stride=(2, 2),
260+
mesh=mesh,
261+
dtype=dtype,
262+
weights_dtype=weights_dtype,
263+
precision=precision
264+
)
265+
self.time_conv = WanCausalConv3d(
266+
rngs=rngs,
267+
in_channels=dim,
268+
out_channels=dim,
269+
kernel_size=(3, 1, 1),
270+
stride=(2, 1, 1),
271+
padding=(0, 0, 0),
272+
mesh=mesh,
273+
dtype=dtype,
274+
weights_dtype=weights_dtype,
275+
precision=precision
276+
)
201277
else:
202278
self.resample = Identity()
203279

@@ -215,9 +291,9 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
215291
new_cache = {}
216292

217293
if self.mode == "upsample2d":
218-
# Use self.resample (Sequential)
219294
b, t, h, w, c = x.shape
220295
x = x.reshape(b * t, h, w, c)
296+
# self.resample is an nnx.Sequential here (WanUpsample followed by Conv), called with x only.
221297
x = self.resample(x)
222298
h_new, w_new, c_new = x.shape[1:]
223299
x = x.reshape(b, t, h_new, w_new, c_new)
@@ -239,9 +315,11 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
239315
x = x.reshape(b, t, h_new, w_new, c_new)
240316

241317
elif self.mode == "downsample2d":
242-
# Use self.resample (ZeroPaddedConv2D)
243318
b, t, h, w, c = x.shape
244319
x = x.reshape(b * t, h, w, c)
320+
# ZeroPaddedConv2D returns (out, cache) because of wrapper,
321+
# whereas the Sequential used for upsample2d is called with x alone and returns only the output.
322+
# Here self.resample is ZeroPaddedConv2D directly.
245323
x, _ = self.resample(x, None)
246324
h_new, w_new, c_new = x.shape[1:]
247325
x = x.reshape(b, t, h_new, w_new, c_new)
@@ -258,7 +336,12 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
258336

259337
else:
260338
if hasattr(self, "resample"):
261-
x, _ = self.resample(x, None)
339+
# Identity check
340+
if isinstance(self.resample, Identity):
341+
x, _ = self.resample(x, None)
342+
else:
343+
# Fallback: any non-Identity resample module is called with x alone and returns only the output.
344+
x = self.resample(x)
262345

263346
return x, new_cache
264347

0 commit comments

Comments
 (0)