@@ -265,24 +265,16 @@ def loop_body(step, vals):
265265 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
266266
267267 latents_input = latents
268- condition_input = condition
269- prompt_embeds_input = prompt_embeds
270- image_embeds_input = image_embeds
271268 if do_classifier_free_guidance :
272269 latents_input = jnp .concatenate ([latents , latents ], axis = 0 )
273- condition_input = jnp .concatenate ([condition , condition ], axis = 0 )
274- prompt_embeds_input = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
275- if image_embeds is not None :
276- image_embeds_input = jnp .concatenate ([image_embeds , image_embeds ], axis = 0 )
277-
278270
279271 if expand_timesteps :
280- latent_model_input = (1 - first_frame_mask ) * condition_input + first_frame_mask * latents_input
272+ latent_model_input = (1 - first_frame_mask ) * condition + first_frame_mask * latents_input
281273 temp_ts = (first_frame_mask [0 ][0 ][:, ::2 , ::2 ] * t ).flatten ()
282274 timestep = jnp .expand_dims (temp_ts , axis = 0 )
283275 timestep = jnp .broadcast_to (timestep , (latents_input .shape [0 ], temp_ts .shape [0 ]))
284276 else :
285- latent_model_input = jnp .concatenate ([latents_input , condition_input ], axis = - 1 )
277+ latent_model_input = jnp .concatenate ([latents_input , condition ], axis = - 1 )
286278 timestep = jnp .broadcast_to (t , latents_input .shape [0 ])
287279
288280
@@ -316,7 +308,7 @@ def low_noise_branch(operands):
316308 use_high_noise ,
317309 high_noise_branch ,
318310 low_noise_branch ,
319- (latent_model_input , timestep , prompt_embeds_input , image_embeds_input )
311+ (latent_model_input , timestep , prompt_embeds , image_embeds )
320312 )
321313
322314 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
0 commit comments