@@ -256,6 +256,16 @@ def __call__(
256256 max_logging .log (f"[DEBUG CALL] Decoded video type: { type (decoded_video )} " )
257257 return decoded_video
258258
def check_nan_jit(tensor: jax.Array, name: str, step: jax.Array) -> None:
  """Emit a debug line reporting whether `tensor` contains NaN or Inf values.

  The report goes through `jax.debug.print`, so the NaN/Inf flags and `step`
  may be traced values. The process index, the label and the tensor shape are
  plain Python values and are baked into the message template when this
  function is called.

  Args:
    tensor: Array to inspect. If None, nothing is checked or printed.
    name: Human-readable label for the tensor; printed verbatim.
    step: Loop step index, printed alongside the label.
  """
  if tensor is None:
    return

  has_nans = jnp.isnan(tensor).any()
  has_infs = jnp.isinf(tensor).any()

  # `name` ends up inside a str.format template; escape braces so a label such
  # as "latents{0}" is printed literally instead of being read as a placeholder.
  safe_name = str(name).replace("{", "{{").replace("}", "}}")

  # The shape is static, so format it here rather than passing it as a
  # runtime operand of the debug callback.
  shape_text = str(tuple(tensor.shape))

  template = (
      f"[DEBUG JIT {jax.process_index()}] Step: {{step}} - {safe_name}: "
      f"Shape: {shape_text}, "
      "Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}"
  )
  jax.debug.print(template, step=step, has_nans_val=has_nans, has_infs_val=has_infs)
268+
259269def run_inference_2_1_i2v (
260270 graphdef , sharded_state , rest_of_state ,
261271 latents : jnp .array ,
@@ -289,16 +299,24 @@ def loop_body(step, vals):
289299 rng , timestep_rng = jax .random .split (rng )
290300 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
291301
302+ check_nan_jit (latents , "latents_prev at loop start" , step )
303+
292304 latents_input = latents
293305 if do_classifier_free_guidance :
294306 latents_input = jnp .concatenate ([latents , latents ], axis = 0 )
307+ check_nan_jit (latents_input , "latents_input after CFG concat" , step )
295308
296309 latent_model_input = jnp .concatenate ([latents_input , condition ], axis = - 1 )
310+ check_nan_jit (latent_model_input , "latent_model_input after cond concat" , step )
297311 timestep = jnp .broadcast_to (t , latents_input .shape [0 ])
298312 latent_model_input = jnp .transpose (latent_model_input , (0 , 4 , 1 , 2 , 3 ))
313+ check_nan_jit (latent_model_input , "latent_model_input for transformer" , step )
299314
300315 prompt_embeds_input = prompt_embeds
301316 image_embeds_input = image_embeds
317+ check_nan_jit (prompt_embeds_input , "prompt_embeds_input for transformer" , step )
318+ check_nan_jit (image_embeds_input , "image_embeds_input for transformer" , step )
319+
302320
303321 noise_pred , _ = transformer_forward_pass (
304322 graphdef , sharded_state , rest_of_state ,
@@ -307,11 +325,17 @@ def loop_body(step, vals):
307325 guidance_scale = guidance_scale ,
308326 encoder_hidden_states_image = image_embeds_input ,
309327 )
328+ check_nan_jit (noise_pred , "noise_pred_bcthw from transformer" , step )
310329 noise_pred = jnp .transpose (noise_pred , (0 , 2 , 3 , 4 , 1 ))
330+ check_nan_jit (noise_pred , "noise_pred after transpose" , step )
311331
312332 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
333+ check_nan_jit (latents , "latents_next after scheduler" , step )
313334 latents = latents .astype (original_dtype )
335+ check_nan_jit (latents , "latents_next after dtype cast" , step )
314336 return latents , scheduler_state , rng
315337
338+ max_logging .log (f"Running fori_loop for { num_inference_steps } steps." )
316339 latents , _ , _ = jax .lax .fori_loop (0 , num_inference_steps , loop_body , (latents , scheduler_state , rng ))
340+ max_logging .log ("Finished fori_loop." )
317341 return latents
0 commit comments