Skip to content

Commit 7918a6b

Browse files
committed
Fix
1 parent 7fa7406 commit 7918a6b

1 file changed

Lines changed: 16 additions & 8 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def initialize_cache(self, batch_size, height, width, dtype):
     cache = jnp.zeros((batch_size, CACHE_T, height, width, self.conv.in_features), dtype=dtype)

     # OPTIMIZATION: Spatial Partitioning on Initialization
-    # FIX: Check divisibility before sharding
     if self.mesh is not None and "fsdp" in self.mesh.axis_names:
       num_fsdp_devices = self.mesh.shape["fsdp"]
       # Axis 2 is Height
@@ -114,7 +113,6 @@ def initialize_cache(self, batch_size, height, width, dtype):

   def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> Tuple[jax.Array, jax.Array]:
     # OPTIMIZATION: Spatial Partitioning during execution
-    # FIX: Check divisibility
     if self.mesh is not None and "fsdp" in self.mesh.axis_names:
       height = x.shape[2]
       width = x.shape[3]
@@ -240,7 +238,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
     if self.mode == "upsample2d":
       b, t, h, w, c = x.shape
       x = x.reshape(b * t, h, w, c)
-      x = self.resample(x)
+      x = self.resample(x)  # Sequential
      h_new, w_new, c_new = x.shape[1:]
       x = x.reshape(b, t, h_new, w_new, c_new)

@@ -262,14 +260,14 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
     elif self.mode == "downsample2d":
       b, t, h, w, c = x.shape
       x = x.reshape(b * t, h, w, c)
-      x, _ = self.resample(x, None)
+      x, _ = self.resample(x, None)  # ZeroPaddedConv2D
       h_new, w_new, c_new = x.shape[1:]
       x = x.reshape(b, t, h_new, w_new, c_new)

     elif self.mode == "downsample3d":
       b, t, h, w, c = x.shape
       x = x.reshape(b * t, h, w, c)
-      x, _ = self.resample(x, None)
+      x, _ = self.resample(x, None)  # ZeroPaddedConv2D
       h_new, w_new, c_new = x.shape[1:]
       x = x.reshape(b, t, h_new, w_new, c_new)

@@ -583,6 +581,7 @@ def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencode
     init_cache = self.encoder.init_cache(b, h, w, x.dtype)

     def scan_fn(carry, input_slice):
+      # Expand Time dimension for Conv3d
       input_slice = jnp.expand_dims(input_slice, 1)
       out_slice, new_carry = self.encoder(input_slice, carry)
       # Squeeze Time dimension for scan stacking
@@ -610,14 +609,23 @@ def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOut
     init_cache = self.decoder.init_cache(b, h, w, x.dtype)

     def scan_fn(carry, input_slice):
+      # Expand Time dimension for Conv3d
       input_slice = jnp.expand_dims(input_slice, 1)
       out_slice, new_carry = self.decoder(input_slice, carry)
-      # Squeeze Time dimension for scan stacking
-      out_slice = jnp.squeeze(out_slice, 1)
+      # Don't squeeze here; keep the upsampled frames (B, 4, H, W, C)
       return new_carry, out_slice

     final_cache, decoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
-    decoded = jnp.swapaxes(decoded_frames, 0, 1)
+
+    # decoded_frames shape: (T_lat, B, 4, H, W, C)
+    # We need to flatten T_lat and 4.
+    # Transpose to (B, T_lat, 4, H, W, C)
+    decoded = jnp.transpose(decoded_frames, (1, 0, 2, 3, 4, 5))
+
+    # Reshape to (B, T_lat*4, H, W, C)
+    b, t_lat, t_sub, h, w, c = decoded.shape
+    decoded = decoded.reshape(b, t_lat * t_sub, h, w, c)
+
     out = jnp.clip(decoded, min=-1.0, max=1.0)

     if not return_dict: return (out,)

0 commit comments

Comments
 (0)