@@ -583,7 +583,10 @@ def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencode
583583 init_cache = self .encoder .init_cache (b , h , w , x .dtype )
584584
def scan_fn(carry, input_slice):
    """Run the encoder on one scanned slice and thread the cache through.

    Used as the step function for ``jax.lax.scan``: ``carry`` is the cache
    produced by the previous step (initially ``self.encoder.init_cache``),
    and ``input_slice`` is one element along the scanned leading axis.

    Returns:
        ``(next_cache, frame_out)`` -- the updated carry followed by the
        per-step output that ``scan`` stacks.
    """
    # The encoder is fed a singleton axis at position 1 (the time axis,
    # per the original comment) -- assumes the slice is batch-leading;
    # TODO confirm against how x_scan is built.
    frame_out, next_cache = self.encoder(
        jnp.expand_dims(input_slice, axis=1), carry
    )
    # Drop the singleton axis again so scan stacks plain per-step outputs.
    return next_cache, jnp.squeeze(frame_out, axis=1)
588591
589592 final_cache , encoded_frames = jax .lax .scan (scan_fn , init_cache , x_scan )
@@ -607,7 +610,10 @@ def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOut
607610 init_cache = self .decoder .init_cache (b , h , w , x .dtype )
608611
def scan_fn(carry, input_slice):
    """Run the decoder on one scanned slice and thread the cache through.

    Used as the step function for ``jax.lax.scan``: ``carry`` is the cache
    produced by the previous step (initially ``self.decoder.init_cache``),
    and ``input_slice`` is one element along the scanned leading axis.

    Returns:
        ``(next_cache, frame_out)`` -- the updated carry followed by the
        per-step output that ``scan`` stacks.
    """
    # The decoder is fed a singleton axis at position 1 (the time axis,
    # per the original comment) -- assumes the slice is batch-leading;
    # TODO confirm against how x_scan is built.
    frame_out, next_cache = self.decoder(
        jnp.expand_dims(input_slice, axis=1), carry
    )
    # Drop the singleton axis again so scan stacks plain per-step outputs.
    # NOTE(review): squeeze raises if the decoder emits more than one
    # frame per step -- confirm the decoder never upsamples along axis 1.
    return next_cache, jnp.squeeze(frame_out, axis=1)
612618
613619 final_cache , decoded_frames = jax .lax .scan (scan_fn , init_cache , x_scan )
0 commit comments