@@ -95,7 +95,6 @@ def initialize_cache(self, batch_size, height, width, dtype):
         cache = jnp.zeros((batch_size, CACHE_T, height, width, self.conv.in_features), dtype=dtype)

         # OPTIMIZATION: Spatial Partitioning on Initialization
-        # FIX: Check divisibility before sharding
         if self.mesh is not None and "fsdp" in self.mesh.axis_names:
             num_fsdp_devices = self.mesh.shape["fsdp"]
             # Axis 2 is Height
@@ -114,7 +113,6 @@ def initialize_cache(self, batch_size, height, width, dtype):

     def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> Tuple[jax.Array, jax.Array]:
         # OPTIMIZATION: Spatial Partitioning during execution
-        # FIX: Check divisibility
         if self.mesh is not None and "fsdp" in self.mesh.axis_names:
             height = x.shape[2]
             width = x.shape[3]
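The two deleted `# FIX: Check divisibility` comments referred to a guard that only applies the sharding constraint when the height divides evenly across the `fsdp` axis. A minimal sketch of that pattern, assuming a `jax.sharding.Mesh` with an `fsdp` axis; the helper name `maybe_shard_height` is hypothetical and not part of this PR:

```python
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def maybe_shard_height(x: jax.Array, mesh: Mesh) -> jax.Array:
    """Shard axis 2 (height) over the fsdp axis only when it divides evenly.

    Hypothetical helper sketching the divisibility-before-sharding pattern;
    not the PR's actual implementation.
    """
    if mesh is None or "fsdp" not in mesh.axis_names:
        return x
    num_fsdp_devices = mesh.shape["fsdp"]
    if x.shape[2] % num_fsdp_devices != 0:
        return x  # uneven height: skip the constraint rather than error out
    # x is (B, T, H, W, C); split H across the fsdp devices
    spec = P(None, None, "fsdp", None, None)
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))
```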
@@ -240,7 +238,7 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
         if self.mode == "upsample2d":
             b, t, h, w, c = x.shape
             x = x.reshape(b * t, h, w, c)
-            x = self.resample(x)
+            x = self.resample(x)  # Sequential
             h_new, w_new, c_new = x.shape[1:]
             x = x.reshape(b, t, h_new, w_new, c_new)

@@ -262,14 +260,14 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
         elif self.mode == "downsample2d":
             b, t, h, w, c = x.shape
             x = x.reshape(b * t, h, w, c)
-            x, _ = self.resample(x, None)
+            x, _ = self.resample(x, None)  # ZeroPaddedConv2D
             h_new, w_new, c_new = x.shape[1:]
             x = x.reshape(b, t, h_new, w_new, c_new)

         elif self.mode == "downsample3d":
             b, t, h, w, c = x.shape
             x = x.reshape(b * t, h, w, c)
-            x, _ = self.resample(x, None)
+            x, _ = self.resample(x, None)  # ZeroPaddedConv2D
             h_new, w_new, c_new = x.shape[1:]
             x = x.reshape(b, t, h_new, w_new, c_new)

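All three resample branches above use the same fold-time-into-batch trick: merge `b` and `t` so a 2D op can run over every frame at once, then restore the time axis. A self-contained sketch, with `resample_2d` standing in for `self.resample`:

```python
import jax.numpy as jnp

def apply_2d_over_time(x: jnp.ndarray, resample_2d) -> jnp.ndarray:
    """Apply a (N, H, W, C) -> (N, H', W', C') op framewise to a video tensor."""
    b, t, h, w, c = x.shape
    x = x.reshape(b * t, h, w, c)       # fold time into batch
    x = resample_2d(x)                  # any per-frame 2D resample
    h_new, w_new, c_new = x.shape[1:]
    return x.reshape(b, t, h_new, w_new, c_new)  # unfold time again

# e.g. apply_2d_over_time(jnp.zeros((2, 5, 8, 8, 3)), lambda f: f[:, ::2, ::2, :]).shape
# -> (2, 5, 4, 4, 3)
```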
@@ -583,6 +581,7 @@ def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencode
         init_cache = self.encoder.init_cache(b, h, w, x.dtype)

         def scan_fn(carry, input_slice):
+            # Expand Time dimension for Conv3d
             input_slice = jnp.expand_dims(input_slice, 1)
             out_slice, new_carry = self.encoder(input_slice, carry)
             # Squeeze Time dimension for scan stacking
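For context, the encode path scans the encoder over time one frame at a time, threading the conv cache through as the carry. A reduced sketch of the same control flow, with `encoder_fn` standing in for `self.encoder` and a plain array standing in for the real cache pytree:

```python
import jax
import jax.numpy as jnp

def encode_frames(encoder_fn, init_cache, x):
    """x: (B, T, H, W, C); encoder_fn consumes one frame plus a cache carry."""
    x_scan = jnp.swapaxes(x, 0, 1)  # scan iterates the leading axis: (T, B, H, W, C)

    def scan_fn(carry, input_slice):
        input_slice = jnp.expand_dims(input_slice, 1)    # (B, 1, H, W, C) for Conv3d
        out_slice, new_carry = encoder_fn(input_slice, carry)
        return new_carry, jnp.squeeze(out_slice, 1)      # drop T=1 so scan can stack

    final_cache, frames = jax.lax.scan(scan_fn, init_cache, x_scan)
    return jnp.swapaxes(frames, 0, 1), final_cache       # back to (B, T, ...)
```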
@@ -610,14 +609,23 @@ def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOut
         init_cache = self.decoder.init_cache(b, h, w, x.dtype)

         def scan_fn(carry, input_slice):
+            # Expand Time dimension for Conv3d
             input_slice = jnp.expand_dims(input_slice, 1)
             out_slice, new_carry = self.decoder(input_slice, carry)
-            # Squeeze Time dimension for scan stacking
-            out_slice = jnp.squeeze(out_slice, 1)
+            # Don't squeeze here; keep the upsampled frames (B, 4, H, W, C)
             return new_carry, out_slice

         final_cache, decoded_frames = jax.lax.scan(scan_fn, init_cache, x_scan)
-        decoded = jnp.swapaxes(decoded_frames, 0, 1)
+
+        # decoded_frames shape: (T_lat, B, 4, H, W, C)
+        # We need to flatten T_lat and 4.
+        # Transpose to (B, T_lat, 4, H, W, C)
+        decoded = jnp.transpose(decoded_frames, (1, 0, 2, 3, 4, 5))
+
+        # Reshape to (B, T_lat*4, H, W, C)
+        b, t_lat, t_sub, h, w, c = decoded.shape
+        decoded = decoded.reshape(b, t_lat * t_sub, h, w, c)
+
         out = jnp.clip(decoded, min=-1.0, max=1.0)

         if not return_dict: return (out,)