@@ -257,15 +257,6 @@ def __call__(
257257 max_logging .log (f"[DEBUG CALL] Decoded video type: { type (decoded_video )} " )
258258 return decoded_video
259259
260- def check_nan_jit (tensor : jax .Array , name : str , step : jax .Array ):
261- if tensor is None :
262- return
263-
264- has_nans = jnp .isnan (tensor ).any ()
265- has_infs = jnp .isinf (tensor ).any ()
266- jax .debug .print (f"[DEBUG JIT { jax .process_index ()} ] Step: {{step}} - { name } : "
267- "Shape: {shape}, Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}" ,
268- step = step , shape = tensor .shape , has_nans_val = has_nans , has_infs_val = has_infs )
269260
270261def run_inference_2_1_i2v (
271262 graphdef , sharded_state , rest_of_state ,
@@ -300,23 +291,16 @@ def loop_body(step, vals):
300291 rng , timestep_rng = jax .random .split (rng )
301292 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
302293
303- check_nan_jit (latents , "latents_prev at loop start" , step )
304-
305294 latents_input = latents
306295 if do_classifier_free_guidance :
307296 latents_input = jnp .concatenate ([latents , latents ], axis = 0 )
308- check_nan_jit (latents_input , "latents_input after CFG concat" , step )
309297
310298 latent_model_input = jnp .concatenate ([latents_input , condition ], axis = - 1 )
311- check_nan_jit (latent_model_input , "latent_model_input after cond concat" , step )
312299 timestep = jnp .broadcast_to (t , latents_input .shape [0 ])
313300 latent_model_input = jnp .transpose (latent_model_input , (0 , 4 , 1 , 2 , 3 ))
314- check_nan_jit (latent_model_input , "latent_model_input for transformer" , step )
315301
316302 prompt_embeds_input = prompt_embeds
317303 image_embeds_input = image_embeds
318- check_nan_jit (prompt_embeds_input , "prompt_embeds_input for transformer" , step )
319- check_nan_jit (image_embeds_input , "image_embeds_input for transformer" , step )
320304
321305
322306 noise_pred , _ = transformer_forward_pass (
@@ -326,14 +310,9 @@ def loop_body(step, vals):
326310 guidance_scale = guidance_scale ,
327311 encoder_hidden_states_image = image_embeds_input ,
328312 )
329- check_nan_jit (noise_pred , "noise_pred_bcthw from transformer" , step )
330313 noise_pred = jnp .transpose (noise_pred , (0 , 2 , 3 , 4 , 1 ))
331- check_nan_jit (noise_pred , "noise_pred after transpose" , step )
332-
333314 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
334- check_nan_jit (latents , "latents_next after scheduler" , step )
335315 latents = latents .astype (original_dtype )
336- check_nan_jit (latents , "latents_next after dtype cast" , step )
337316 return latents , scheduler_state , rng
338317
339318 max_logging .log (f"Running fori_loop for { num_inference_steps } steps." )
0 commit comments