@@ -269,30 +269,23 @@ def loop_body(step, vals):
269269
270270 latents_input = latents
271271 if do_classifier_free_guidance :
272- condition_input = jnp .concatenate ([condition , condition ], axis = 0 )
273272 latents_input = jnp .concatenate ([latents , latents ], axis = 0 )
274- else :
275- condition_input = condition
276273
277- latent_model_input = jnp .concatenate ([latents_input , condition_input ], axis = - 1 )
274+ latent_model_input = jnp .concatenate ([latents_input , condition ], axis = - 1 )
278275 timestep = jnp .broadcast_to (t , latents_input .shape [0 ])
279276 latent_model_input = jnp .transpose (latent_model_input , (0 , 4 , 1 , 2 , 3 ))
280277
281278 prompt_embeds_input = prompt_embeds
282- image_embeds_input = image_embeds
283279 if do_classifier_free_guidance :
284280 prompt_embeds_input = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
285- if image_embeds is not None :
286- image_embeds_input = jnp .concatenate ([image_embeds , image_embeds ], axis = 0 )
287-
288281
289282
290283 noise_pred , latents = transformer_forward_pass (
291284 graphdef , sharded_state , rest_of_state ,
292285 latent_model_input , timestep , prompt_embeds_input ,
293286 do_classifier_free_guidance = do_classifier_free_guidance ,
294287 guidance_scale = guidance_scale ,
295- encoder_hidden_states_image = image_embeds_input ,
288+ encoder_hidden_states_image = image_embeds ,
296289 )
297290 noise_pred = jnp .transpose (noise_pred , (0 , 2 , 3 , 4 , 1 ))
298291
0 commit comments