@@ -360,35 +360,52 @@ def multistep_uni_p_bh_update(
360360 raise NotImplementedError ("Nested `solver_p` is not implemented in JAX version yet." )
361361
362362 m0 = state .model_outputs [self .config .solver_order - 1 ] # Most recent stored converted model output
363+ check_nan_jit (m0 , "P m0" , step )
363364 x = sample
365+ check_nan_jit (x , "P sample" , step )
364366
365367 sigma_t_val , sigma_s0_val = (
366368 state .sigmas [state .step_index + 1 ],
367369 state .sigmas [state .step_index ],
368370 )
371+ check_nan_jit (sigma_t_val , "P sigma_t_val" , step )
372+ check_nan_jit (sigma_s0_val , "P sigma_s0_val" , step )
373+
369374
370375 alpha_t , sigma_t = self ._sigma_to_alpha_sigma_t (sigma_t_val )
376+ check_nan_jit (alpha_t , "P alpha_t" , step )
377+ check_nan_jit (sigma_t , "P sigma_t" , step )
378+
371379 alpha_s0 , sigma_s0 = self ._sigma_to_alpha_sigma_t (sigma_s0_val )
380+ check_nan_jit (alpha_s0 , "P alpha_s0" , step )
381+ check_nan_jit (sigma_s0 , "P sigma_s0" , step )
372382
373383 lambda_t = jnp .log (alpha_t + 1e-10 ) - jnp .log (sigma_t + 1e-10 )
384+ check_nan_jit (lambda_t , "P lambda_t" , step )
374385 lambda_s0 = jnp .log (alpha_s0 + 1e-10 ) - jnp .log (sigma_s0 + 1e-10 )
386+ check_nan_jit (lambda_s0 , "P lambda_s0" , step )
375387
376388 h = lambda_t - lambda_s0
377- check_nan_jit (h , "predictor h" , step )
389+ check_nan_jit (h , "P h" , step )
378390
379391 def rk_d1_loop_body (i , carry ):
380392 # Loop from i = 0 to order-2
381393 rks , D1s = carry
382394 history_idx = self .config .solver_order - 2 - i
383395 mi = state .model_outputs [history_idx ]
396+ check_nan_jit (mi , f"P rk_d1 mi[{ i } ]" , step )
384397 si_val = state .timestep_list [history_idx ]
385398
386399 alpha_si , sigma_si = self ._sigma_to_alpha_sigma_t (state .sigmas [self .index_for_timestep (state , si_val )])
400+ check_nan_jit (alpha_si , f"P rk_d1 alpha_si[{ i } ]" , step )
401+ check_nan_jit (sigma_si , f"P rk_d1 sigma_si[{ i } ]" , step )
387402 lambda_si = jnp .log (alpha_si + 1e-10 ) - jnp .log (sigma_si + 1e-10 )
403+ check_nan_jit (lambda_si , f"P rk_d1 lambda_si[{ i } ]" , step )
388404
389405 rk = (lambda_si - lambda_s0 ) / h
406+ check_nan_jit (rk , f"P rk_d1 rk[{ i } ]" , step )
390407 Di = (mi - m0 ) / rk
391- check_nan_jit (Di , f"predictor Di[{ i } ]" , step )
408+ check_nan_jit (Di , f"P rk_d1 Di[{ i } ]" , step )
392409
393410 rks = rks .at [i ].set (rk )
394411 D1s = D1s .at [i ].set (Di )
@@ -400,27 +417,37 @@ def rk_d1_loop_body(i, carry):
400417 # Dummy D1s array. It will not be used if order == 1
401418 D1s_init = jnp .zeros ((1 , * m0 .shape ), dtype = m0 .dtype )
402419 rks , D1s = jax .lax .fori_loop (0 , order - 1 , rk_d1_loop_body , (rks_init , D1s_init ))
420+ check_nan_jit (rks , "P rks after loop" , step )
421+ check_nan_jit (D1s , "P D1s after loop" , step )
403422 rks = rks .at [order - 1 ].set (1.0 )
423+ check_nan_jit (rks , "P rks final" , step )
404424
405425 hh = - h if self .config .predict_x0 else h
426+ check_nan_jit (hh , "P hh" , step )
406427 h_phi_1 = jnp .expm1 (hh )
428+ check_nan_jit (h_phi_1 , "P h_phi_1" , step )
407429
408430 if self .config .solver_type == "bh1" :
409431 B_h = hh
410432 elif self .config .solver_type == "bh2" :
411433 B_h = jnp .expm1 (hh )
412434 else :
413435 raise NotImplementedError ()
436+ check_nan_jit (B_h , "P B_h" , step )
414437
415438 def rb_loop_body (i , carry ):
416439 R , b , current_h_phi_k , factorial_val = carry
440+ check_nan_jit (current_h_phi_k , f"P rb_loop[{ i } ] current_h_phi_k IN" , step )
441+ check_nan_jit (factorial_val , f"P rb_loop[{ i } ] factorial_val IN" , step )
417442 R = R .at [i ].set (jnp .power (rks , i ))
418443 b = b .at [i ].set (current_h_phi_k * factorial_val / B_h )
419444
420445 def update_fn (vals ):
421446 _h_phi_k , _fac = vals
422447 next_fac = _fac * (i + 2 )
448+ check_nan_jit (next_fac , f"P rb_loop[{ i } ] next_fac" , step )
423449 next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac
450+ check_nan_jit (next_h_phi_k , f"P rb_loop[{ i } ] next_h_phi_k" , step )
424451 return next_h_phi_k , next_fac
425452
426453 current_h_phi_k , factorial_val = jax .lax .cond (
@@ -434,11 +461,16 @@ def update_fn(vals):
434461 R_init = jnp .zeros ((self .config .solver_order , self .config .solver_order ), dtype = h .dtype )
435462 b_init = jnp .zeros (self .config .solver_order , dtype = h .dtype )
436463 init_h_phi_k = h_phi_1 / hh - 1.0
464+ check_nan_jit (init_h_phi_k , "P init_h_phi_k" , step )
437465 init_factorial = 1.0
438466 R , b , _ , _ = jax .lax .fori_loop (0 , order , rb_loop_body , (R_init , b_init , init_h_phi_k , init_factorial ))
467+ check_nan_jit (R , "P R after loop" , step )
468+ check_nan_jit (b , "P b after loop" , step )
469+
439470
440471 if len (D1s ) > 0 :
441472 D1s = jnp .stack (D1s , axis = 1 ) # Resulting shape (B, K, C, H, W)
473+ check_nan_jit (D1s , "P D1s_stacked" , step )
442474
443475 def solve_for_rhos_p (R_mat , b_vec , current_order ):
444476 # Create a mask for the top-left (current_order - 1) x (current_order - 1) sub-matrix
@@ -453,9 +485,12 @@ def solve_for_rhos_p(R_mat, b_vec, current_order):
453485 jnp .eye (mask_size , dtype = R_mat .dtype ),
454486 )
455487 b_safe = jnp .where (mask , b_vec [:mask_size ], 0.0 )
488+ check_nan_jit (R_safe , "P solve R_safe" , step )
489+ check_nan_jit (b_safe , "P solve b_safe" , step )
456490
457491 # Solve the system and mask the result
458492 solved_rhos = jnp .linalg .solve (R_safe , b_safe )
493+ check_nan_jit (solved_rhos , "P solve solved_rhos" , step )
459494 return jnp .where (mask , solved_rhos , 0.0 )
460495
461496 # Handle the special case for order == 2
@@ -467,9 +502,11 @@ def solve_for_rhos_p(R_mat, b_vec, current_order):
467502
468503 # Get the result for the general case
469504 rhos_p_general = solve_for_rhos_p (R , b , order )
505+ check_nan_jit (rhos_p_general , "P rhos_p_general" , step )
470506
471507 # Select the appropriate result based on the order
472508 rhos_p = jnp .where (order == 2 , rhos_p_order2 , rhos_p_general )
509+ check_nan_jit (rhos_p , "P rhos_p" , step )
473510
474511 pred_res = jax .lax .cond (
475512 order > 1 ,
@@ -478,14 +515,21 @@ def solve_for_rhos_p(R_mat, b_vec, current_order):
478515 lambda _ : jnp .zeros_like (x ),
479516 operand = None ,
480517 )
518+ check_nan_jit (pred_res , "P pred_res" , step )
481519
482520 if self .config .predict_x0 :
483521 x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
484- x_t = x_t_ - alpha_t * B_h * pred_res
522+ check_nan_jit (x_t_ , "P x_t_ term" , step )
523+ term2 = alpha_t * B_h * pred_res
524+ check_nan_jit (term2 , "P term2" , step )
525+ x_t = x_t_ - term2
485526 else : # Predict epsilon
486527 x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
487- x_t = x_t_ - sigma_t * B_h * pred_res
488- check_nan_jit (x_t , "predictor x_t" , step )
528+ check_nan_jit (x_t_ , "P x_t_ term eps" , step )
529+ term2 = sigma_t * B_h * pred_res
530+ check_nan_jit (term2 , "P term2 eps" , step )
531+ x_t = x_t_ - term2
532+ check_nan_jit (x_t , "P final x_t" , step )
489533 return x_t .astype (x .dtype )
490534
491535 def multistep_uni_c_bh_update (
0 commit comments