@@ -112,15 +112,17 @@ def prepare_latents(
112112 if last_image is None :
113113 mask_lat_size = mask_lat_size .at [:, :, 1 :, :, :].set (0 )
114114 else :
115- mask_lat_size = mask_lat_size .at [:, :, 1 :- 1 , :, :].set (0 )
115+ mask_lat_size = mask_lat_size .at [:, :, 1 :- 1 , :, :].set (0 )
116116 first_frame_mask = mask_lat_size [:, :, 0 :1 ]
117117 first_frame_mask = jnp .repeat (first_frame_mask , self .vae_scale_factor_temporal , axis = 2 )
118118 mask_lat_size = jnp .concatenate ([first_frame_mask , mask_lat_size [:, :, 1 :]], axis = 2 )
119119 mask_lat_size = mask_lat_size .reshape (
120120 batch_size , - 1 , self .vae_scale_factor_temporal , latent_height , latent_width
121121 )
122122 mask_lat_size = jnp .swapaxes (mask_lat_size , 1 , 2 )
123- condition = jnp .concatenate ([mask_lat_size , latent_condition ], axis = 1 )
123+ mask_lat_size = jnp .transpose (mask_lat_size , (0 , 2 , 3 , 4 , 1 ))
124+ condition = jnp .concatenate ([mask_lat_size , latent_condition ], axis = - 1 )
125+
124126 return latents , condition , None
125127
126128
0 commit comments