Skip to content

Commit ec668df

Browse files
committed
modify downsample3d
1 parent 081feff commit ec668df

1 file changed

Lines changed: 21 additions & 10 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -421,18 +421,29 @@ def __call__(
421421
x = x.reshape(b, t, h_new, w_new, c_new)
422422

423423
elif self.mode == "downsample3d":
424-
x, tc_cache = self.time_conv(x, cache.get("time_conv"))
425-
new_cache["time_conv"] = tc_cache
426-
print(f"WanResample ({self.mode}) after time_conv: {x.shape}")
424+
if x.shape[1] >= self.time_conv.kernel_size[0]:
425+
x, tc_cache = self.time_conv(x, cache.get("time_conv"))
426+
new_cache["time_conv"] = tc_cache
427+
print(f"WanResample ({self.mode}) after time_conv: {x.shape}")
428+
else:
429+
# Skip temporal downsampling if not enough frames
430+
print(f"WanResample ({self.mode}): Skipping time_conv, input time dim {x.shape[1]} < kernel {self.time_conv.kernel_size[0]}")
431+
new_cache["time_conv"] = cache.get("time_conv") # Pass through cache
427432

428433
b, t, h, w, c = x.shape
429-
x = x.reshape(b * t, h, w, c)
430-
print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
431-
x, _ = self.resample(x, None)
432-
print(f"WanResample ({self.mode}) after resample: {x.shape}")
433-
h_new, w_new, c_new = x.shape[1:]
434-
x = x.reshape(b, t, h_new, w_new, c_new)
435-
434+
if b * t > 0:
435+
x = x.reshape(b * t, h, w, c)
436+
print(f"WanResample ({self.mode}) reshaped for resample: {x.shape}")
437+
x, _ = self.resample(x, None)
438+
print(f"WanResample ({self.mode}) after resample: {x.shape}")
439+
h_new, w_new, c_new = x.shape[1:]
440+
x = x.reshape(b, t, h_new, w_new, c_new)
441+
else:
442+
# If batch or time dimension is 0 (b * t == 0), skip the resample conv and build an empty output with the downsampled spatial shape
443+
h_new, w_new = h // self.resample.conv.strides[0], w // self.resample.conv.strides[1]
444+
c_new = self.resample.conv.out_features
445+
x = jnp.zeros((b, t, h_new, w_new, c_new), dtype=x.dtype)
446+
print(f"WanResample ({self.mode}): Spatial downsample output shape {x.shape} (due to t=0)")
436447
else:
437448
if hasattr(self, "resample"):
438449
if isinstance(self.resample, Identity):

0 commit comments

Comments
 (0)