Skip to content

Commit 7fa7406

Browse files
committed
Fix
1 parent 90f6b06 commit 7fa7406

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -583,7 +583,10 @@ def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencode
583583
init_cache = self.encoder.init_cache(b, h, w, x.dtype)
584584

585585
def scan_fn(carry, input_slice):
586+
input_slice = jnp.expand_dims(input_slice, 1)
586587
out_slice, new_carry = self.encoder(input_slice, carry)
588+
# Squeeze Time dimension for scan stacking
589+
out_slice = jnp.squeeze(out_slice, 1)
587590
return new_carry, out_slice
588591

589592
final_cache, encoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
@@ -607,7 +610,10 @@ def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOut
607610
init_cache = self.decoder.init_cache(b, h, w, x.dtype)
608611

609612
def scan_fn(carry, input_slice):
613+
input_slice = jnp.expand_dims(input_slice, 1)
610614
out_slice, new_carry = self.decoder(input_slice, carry)
615+
# Squeeze Time dimension for scan stacking
616+
out_slice = jnp.squeeze(out_slice, 1)
611617
return new_carry, out_slice
612618

613619
final_cache, decoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)

0 commit comments

Comments
 (0)