@@ -269,20 +269,30 @@ def loop_body(step, vals):
269269
270270 latents_input = latents
271271 if do_classifier_free_guidance :
272+ condition_input = jnp .concatenate ([condition , condition ], axis = 0 )
272273 latents_input = jnp .concatenate ([latents , latents ], axis = 0 )
274+ else :
275+ condition_input = condition
273276
274- latent_model_input = jnp .concatenate ([latents_input , condition ], axis = - 1 )
277+ latent_model_input = jnp .concatenate ([latents_input , condition_input ], axis = - 1 )
275278 timestep = jnp .broadcast_to (t , latents_input .shape [0 ])
276279 latent_model_input = jnp .transpose (latent_model_input , (0 , 4 , 1 , 2 , 3 ))
277- timestep = jnp .broadcast_to (t , latents .shape [0 ])
280+
281+ prompt_embeds_input = prompt_embeds
282+ image_embeds_input = image_embeds
283+ if do_classifier_free_guidance :
284+ prompt_embeds_input = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
285+ if image_embeds is not None :
286+ image_embeds_input = jnp .concatenate ([image_embeds , image_embeds ], axis = 0 )
287+
278288
279289
280290 noise_pred , latents = transformer_forward_pass (
281291 graphdef , sharded_state , rest_of_state ,
282- latent_model_input , timestep , prompt_embeds ,
292+ latent_model_input , timestep , prompt_embeds_input ,
283293 do_classifier_free_guidance = do_classifier_free_guidance ,
284294 guidance_scale = guidance_scale ,
285- encoder_hidden_states_image = image_embeds ,
295+ encoder_hidden_states_image = image_embeds_input ,
286296 )
287297 noise_pred = jnp .transpose (noise_pred , (0 , 2 , 3 , 4 , 1 ))
288298
0 commit comments