@@ -85,6 +85,7 @@ def prepare_latents(
8585 rng : jax .Array ,
8686 latents : Optional [jax .Array ] = None ,
8787 last_image : Optional [jax .Array ] = None ,
88+ num_videos_per_prompt : int = 1 ,
8889 ) -> Tuple [jax .Array , jax .Array , Optional [jax .Array ]]:
8990
9091 if hasattr (image , "detach" ):
@@ -96,12 +97,17 @@ def prepare_latents(
9697 last_image = last_image .detach ().cpu ().numpy ()
9798 last_image = jnp .array (last_image )
9899
100+ if num_videos_per_prompt > 1 :
101+ image = jnp .repeat (image , num_videos_per_prompt , axis = 0 )
102+ if last_image is not None :
103+ last_image = jnp .repeat (last_image , num_videos_per_prompt , axis = 0 )
104+
99105 num_channels_latents = self .vae .z_dim
100106 num_latent_frames = (num_frames - 1 ) // self .vae_scale_factor_temporal + 1
101107 latent_height = height // self .vae_scale_factor_spatial
102108 latent_width = width // self .vae_scale_factor_spatial
103109
104- shape = (batch_size , num_channels_latents , num_latent_frames , latent_height , latent_width )
110+ shape = (batch_size , num_latent_frames , latent_height , latent_width , num_channels_latents )
105111
106112 if latents is None :
107113 latents = randn_tensor (shape , rng , self .config , dtype )
@@ -119,7 +125,6 @@ def prepare_latents(
119125 mask_lat_size = mask_lat_size .reshape (
120126 batch_size , - 1 , self .vae_scale_factor_temporal , latent_height , latent_width
121127 )
122- mask_lat_size = jnp .swapaxes (mask_lat_size , 1 , 2 )
123128 mask_lat_size = jnp .transpose (mask_lat_size , (0 , 2 , 3 , 4 , 1 ))
124129 condition = jnp .concatenate ([mask_lat_size , latent_condition ], axis = - 1 )
125130
@@ -146,13 +151,20 @@ def __call__(
146151 output_type : Optional [str ] = "np" ,
147152 rng : Optional [jax .Array ] = None ,
148153 ):
154+
155+ if num_videos_per_prompt == 1 :
156+ n_devices = jax .device_count ()
157+ if n_devices > 1 :
158+ num_videos_per_prompt = n_devices
159+
149160 height = height or self .config .height
150161 width = width or self .config .width
151162 num_frames = num_frames or self .config .num_frames
152163
153164 prompt_embeds , negative_prompt_embeds , image_embeds , effective_batch_size = self ._prepare_model_inputs_i2v (
154165 prompt , image , negative_prompt , num_videos_per_prompt , max_sequence_length ,
155- prompt_embeds , negative_prompt_embeds , image_embeds , last_image
166+ prompt_embeds , negative_prompt_embeds , image_embeds , last_image ,
167+ num_videos_per_prompt = num_videos_per_prompt ,
156168 )
157169
158170 image_tensor = self .video_processor .preprocess (image , height = height , width = width )
@@ -174,6 +186,7 @@ def __call__(
174186 rng = latents_rng ,
175187 latents = latents ,
176188 last_image = last_image_tensor ,
189+ num_videos_per_prompt = num_videos_per_prompt ,
177190 )
178191
179192 scheduler_state = self .scheduler .set_timesteps (
@@ -254,7 +267,7 @@ def loop_body(step, vals):
254267 if do_classifier_free_guidance :
255268 latents_input = jnp .concatenate ([latents , latents ], axis = 0 )
256269
257- latent_model_input = jnp .concatenate ([latents_input , condition ], axis = 1 )
270+ latent_model_input = jnp .concatenate ([latents_input , condition ], axis = - 1 )
258271 timestep = jnp .broadcast_to (t , latents .shape [0 ])
259272
260273
0 commit comments