Skip to content

Commit 557ee7d

Browse files
committed
Fix errors
1 parent 8a49410 commit 557ee7d

1 file changed

Lines changed: 5 additions & 8 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,6 @@ def __init__(self, dim: int, mode: str, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
178178
self.dim = dim
179179
self.mode = mode
180180
self.time_conv = None
181-
182-
# We unpack nnx.Sequential to handle cache logic explicitly
183-
self.upsample = None
184-
self.conv = None
185-
self.downsample_conv = None
186181

187182
if mode == "upsample2d":
188183
self.upsample = WanUpsample(scale_factor=(2.0, 2.0), method="nearest")
@@ -199,7 +194,7 @@ def __init__(self, dim: int, mode: str, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
199194

200195
def initialize_cache(self, batch_size, height, width, dtype):
201196
cache = {}
202-
if self.time_conv is not None:
197+
if hasattr(self, "time_conv"):
203198
h_curr, w_curr = height, width
204199
if self.mode == "downsample3d":
205200
# Resample (stride 2) happens before time conv
@@ -251,10 +246,12 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
251246
x, _ = self.downsample_conv(x, None)
252247
h_new, w_new, c_new = x.shape[1:]
253248
x = x.reshape(b, t, h_new, w_new, c_new)
254-
249+
255250
x, tc_cache = self.time_conv(x, cache.get("time_conv"))
256251
new_cache["time_conv"] = tc_cache
257-
252+
else:
253+
if hasattr(self, "resample"):
254+
x, _ = self.resample(x, None)
258255
return x, new_cache
259256

260257

0 commit comments

Comments
 (0)