@@ -265,16 +265,24 @@ def loop_body(step, vals):
265265 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
266266
267267 latents_input = latents
268+ condition_input = condition
269+ prompt_embeds_input = prompt_embeds
270+ image_embeds_input = image_embeds
268271 if do_classifier_free_guidance :
269272 latents_input = jnp .concatenate ([latents , latents ], axis = 0 )
273+ condition_input = jnp .concatenate ([condition , condition ], axis = 0 )
274+ prompt_embeds_input = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
275+ if image_embeds is not None :
276+ image_embeds_input = jnp .concatenate ([image_embeds , image_embeds ], axis = 0 )
277+
270278
271279 if expand_timesteps :
272- latent_model_input = (1 - first_frame_mask ) * condition + first_frame_mask * latents_input
280+ latent_model_input = (1 - first_frame_mask ) * condition_input + first_frame_mask * latents_input
273281 temp_ts = (first_frame_mask [0 ][0 ][:, ::2 , ::2 ] * t ).flatten ()
274282 timestep = jnp .expand_dims (temp_ts , axis = 0 )
275283 timestep = jnp .broadcast_to (timestep , (latents_input .shape [0 ], temp_ts .shape [0 ]))
276284 else :
277- latent_model_input = jnp .concatenate ([latents_input , condition ], axis = - 1 )
285+ latent_model_input = jnp .concatenate ([latents_input , condition_input ], axis = - 1 )
278286 timestep = jnp .broadcast_to (t , latents_input .shape [0 ])
279287
280288
@@ -308,7 +316,7 @@ def low_noise_branch(operands):
308316 use_high_noise ,
309317 high_noise_branch ,
310318 low_noise_branch ,
311- (latent_model_input , timestep , prompt_embeds , image_embeds )
319+ (latent_model_input , timestep , prompt_embeds_input , image_embeds_input )
312320 )
313321
314322 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
0 commit comments