Skip to content

Commit 7be24b5

Browse files
committed
Fix
1 parent 433b2d4 commit 7be24b5

1 file changed

Lines changed: 22 additions & 14 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -177,26 +177,35 @@ class WanResample(nnx.Module):
177177
def __init__(
    self,
    dim: int,
    mode: str,
    rngs: nnx.Rngs,
    mesh: jax.sharding.Mesh = None,
    dtype: jnp.dtype = jnp.float32,
    weights_dtype: jnp.dtype = jnp.float32,
    precision: jax.lax.Precision = None,
):
  """Spatial/temporal resampling layer for the WAN VAE.

  Args:
    dim: Number of input channels.
    mode: One of "upsample2d", "upsample3d", "downsample2d", "downsample3d";
      any other value falls through to an identity resample.
    rngs: NNX RNG collection used to initialize the convolution weights.
    mesh: Optional JAX device mesh forwarded to the 3D/zero-padded convs.
    dtype: Computation dtype for the convolutions.
    weights_dtype: Parameter (storage) dtype for the convolutions.
    precision: Matmul precision forwarded to the convolutions.
  """
  self.dim = dim
  self.mode = mode

  # Shared kwargs for the spatial half-channel conv used by both 2D branches.
  def _half_channel_conv():
    return nnx.Conv(
        dim,
        dim // 2,
        kernel_size=(3, 3),
        padding="SAME",
        use_bias=True,
        rngs=rngs,
        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")),
        dtype=dtype,
        param_dtype=weights_dtype,
        precision=precision,
    )

  if mode == "upsample2d":
    # Wrapped in Sequential so parameter paths line up with the pretrained
    # checkpoint key layout (this was the fix in the originating commit).
    self.resample = nnx.Sequential(
        WanUpsample(scale_factor=(2.0, 2.0), method="nearest"),
        _half_channel_conv(),
    )
  elif mode == "upsample3d":
    # 3D mode keeps explicit attributes because the temporal conv needs
    # separate cache handling in initialize_cache/__call__.
    self.upsample = WanUpsample(scale_factor=(2.0, 2.0), method="nearest")
    self.conv = _half_channel_conv()
    self.time_conv = WanCausalConv3d(
        rngs=rngs,
        in_channels=dim,
        out_channels=dim * 2,
        kernel_size=(3, 1, 1),
        padding=(1, 0, 0),
        mesh=mesh,
        dtype=dtype,
        weights_dtype=weights_dtype,
        precision=precision,
    )
  elif mode == "downsample2d":
    # Named `resample` (rather than `downsample_conv`) to match the
    # checkpoint keys — the commit notes downsample keys were missing too.
    self.resample = ZeroPaddedConv2D(
        dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision
    )
  elif mode == "downsample3d":
    # Explicit attributes, as in upsample3d, for temporal cache handling.
    self.downsample_conv = ZeroPaddedConv2D(
        dim=dim, rngs=rngs, kernel_size=(3, 3), stride=(2, 2), mesh=mesh, dtype=dtype, weights_dtype=weights_dtype, precision=precision
    )
    self.time_conv = WanCausalConv3d(
        rngs=rngs,
        in_channels=dim,
        out_channels=dim,
        kernel_size=(3, 1, 1),
        stride=(2, 1, 1),
        padding=(0, 0, 0),
        mesh=mesh,
        dtype=dtype,
        weights_dtype=weights_dtype,
        precision=precision,
    )
  else:
    # Unknown mode: no-op resample.
    self.resample = Identity()
193203

194204
def initialize_cache(self, batch_size, height, width, dtype):
  """Create the temporal-conv cache for streaming/chunked decoding.

  Only the 3D modes own a `time_conv`; 2D modes return an empty cache.
  For "downsample3d" the spatial stride-2 conv runs before the temporal
  conv, so the cache is sized for the halved spatial resolution.
  """
  if not hasattr(self, "time_conv"):
    return {}
  if self.mode == "downsample3d":
    h, w = height // 2, width // 2
  else:
    h, w = height, width
  return {"time_conv": self.time_conv.initialize_cache(batch_size, h, w, dtype)}
@@ -206,25 +215,22 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
206215
new_cache = {}
207216

208217
if self.mode == "upsample2d":
218+
# Use self.resample (Sequential)
209219
b, t, h, w, c = x.shape
210220
x = x.reshape(b * t, h, w, c)
211-
x = self.upsample(x)
212-
x = self.conv(x)
221+
x = self.resample(x)
213222
h_new, w_new, c_new = x.shape[1:]
214223
x = x.reshape(b, t, h_new, w_new, c_new)
215224

216225
elif self.mode == "upsample3d":
217-
# TimeConv -> Reshape -> Resample
218226
x, tc_cache = self.time_conv(x, cache.get("time_conv"))
219227
new_cache["time_conv"] = tc_cache
220228

221229
b, t, h, w, c = x.shape
222-
# Split channels for time upsample
223230
x = x.reshape(b, t, h, w, 2, c // 2)
224231
x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
225232
x = x.reshape(b, t * 2, h, w, c // 2)
226233

227-
# Spatial resample
228234
b, t, h, w, c = x.shape
229235
x = x.reshape(b * t, h, w, c)
230236
x = self.upsample(x)
@@ -233,9 +239,10 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
233239
x = x.reshape(b, t, h_new, w_new, c_new)
234240

235241
elif self.mode == "downsample2d":
242+
# Use self.resample (ZeroPaddedConv2D)
236243
b, t, h, w, c = x.shape
237244
x = x.reshape(b * t, h, w, c)
238-
x, _ = self.downsample_conv(x, None)
245+
x, _ = self.resample(x, None)
239246
h_new, w_new, c_new = x.shape[1:]
240247
x = x.reshape(b, t, h_new, w_new, c_new)
241248

@@ -245,14 +252,15 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
245252
x, _ = self.downsample_conv(x, None)
246253
h_new, w_new, c_new = x.shape[1:]
247254
x = x.reshape(b, t, h_new, w_new, c_new)
248-
255+
249256
x, tc_cache = self.time_conv(x, cache.get("time_conv"))
250257
new_cache["time_conv"] = tc_cache
258+
251259
else:
252260
if hasattr(self, "resample"):
253-
x, _ = self.resample(x, None)
254-
return x, new_cache
261+
x, _ = self.resample(x, None)
255262

263+
return x, new_cache
256264

257265
class WanResidualBlock(nnx.Module):
258266
def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs, dropout: float = 0.0, non_linearity: str = "silu", mesh: jax.sharding.Mesh = None, dtype: jnp.dtype = jnp.float32, weights_dtype: jnp.dtype = jnp.float32, precision: jax.lax.Precision = None):

0 commit comments

Comments
 (0)