@@ -274,29 +274,35 @@ def loop_body(step, vals):
274274 timestep = jnp .expand_dims (temp_ts , axis = 0 )
275275 timestep = jnp .broadcast_to (timestep , (latents_input .shape [0 ], temp_ts .shape [0 ]))
276276 else :
277- latent_model_input = jnp .concatenate ([latents_input , condition ], axis = 1 )
277+ latent_model_input = jnp .concatenate ([latents_input , condition ], axis = - 1 )
278278 timestep = jnp .broadcast_to (t , latents_input .shape [0 ])
279279
280280
281281 use_high_noise = jnp .greater_equal (t , boundary )
282282
def high_noise_branch(operands):
    """Run the high-noise transformer for one step of the denoising loop.

    Args:
        operands: 4-tuple ``(latents_input, ts_input, pe_input, ie_input)``.
            ``latents_input`` must be 5-D, since it is permuted with a
            5-axis permutation below. ``ie_input`` is forwarded as
            ``encoder_hidden_states_image``.

    Returns:
        ``(noise_pred, latents_out)`` where ``noise_pred`` has been permuted
        back so that axis 1 of the model output becomes the last axis.
    """
    latents_input, ts_input, pe_input, ie_input = operands

    # Move the last axis to position 1 before the forward pass.
    # Presumably channels-last -> channels-first; confirm against the layout
    # transformer_forward_pass expects.
    latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3))

    noise_pred, latents_out = transformer_forward_pass(
        high_noise_graphdef, high_noise_state, high_noise_rest,
        latents_input, ts_input, pe_input,
        do_classifier_free_guidance=do_classifier_free_guidance,
        guidance_scale=guidance_scale,
        encoder_hidden_states_image=ie_input,
    )

    # Inverse of the permutation above: axis 1 goes back to the end.
    noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))

    # NOTE(review): latents_out is returned exactly as produced by the forward
    # pass (no inverse permutation is applied) - confirm the caller expects
    # that layout.
    return noise_pred, latents_out
291294
def low_noise_branch(operands):
    """Run the low-noise transformer for one step of the denoising loop.

    Args:
        operands: 4-tuple ``(latents_input, ts_input, pe_input, ie_input)``.
            ``latents_input`` must be 5-D, since it is permuted with a
            5-axis permutation below. ``ie_input`` is forwarded as
            ``encoder_hidden_states_image``.

    Returns:
        ``(noise_pred, latents_out)`` where ``noise_pred`` has been permuted
        back so that axis 1 of the model output becomes the last axis.
    """
    latents_input, ts_input, pe_input, ie_input = operands

    # Move the last axis to position 1 before the forward pass.
    # Presumably channels-last -> channels-first; confirm against the layout
    # transformer_forward_pass expects.
    latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3))

    # This branch uses the low-noise model parts and guidance_scale_2 rather
    # than guidance_scale.
    noise_pred, latents_out = transformer_forward_pass(
        low_noise_graphdef, low_noise_state, low_noise_rest,
        latents_input, ts_input, pe_input,
        do_classifier_free_guidance=do_classifier_free_guidance,
        guidance_scale=guidance_scale_2,
        encoder_hidden_states_image=ie_input,
    )

    # Inverse of the permutation above: axis 1 goes back to the end.
    noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))

    # NOTE(review): latents_out is returned exactly as produced by the forward
    # pass (no inverse permutation is applied) - confirm the caller expects
    # that layout.
    return noise_pred, latents_out
300306
301307 noise_pred , latents = jax .lax .cond (
302308 use_high_noise ,
0 commit comments