     add_noise_common,
 )

-def check_nan_jit(tensor: jax.Array, name: str, step: jax.Array):
-    if tensor is None:
-        return
-    has_nans = jnp.isnan(tensor).any()
-    has_infs = jnp.isinf(tensor).any()
-    if step is None:
-        step = -1
-
-    # Print the actual dtype of the tensor's data
-    jax.debug.print(f"[DEBUG SCHEDULER {jax.process_index()}] Step: {{step}} - {name}: "
-                    "Shape: {shape}, tensor.dtype: {dtype}, Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}",
-                    step=step, shape=tensor.shape, dtype=tensor.dtype, has_nans_val=has_nans, has_infs_val=has_infs)

 @flax.struct.dataclass
 class UniPCMultistepSchedulerState:
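For reference, the deleted `check_nan_jit` helper follows the standard jit-safe debug pattern: Python-level values (the `name` string, shape, dtype, `jax.process_index()`) are baked into the format string at trace time via the f-string, while traced values are passed to `jax.debug.print` as keyword placeholders so their runtime values are printed on every call. A minimal standalone sketch of the same pattern (names here are illustrative, not from this repo):

import jax
import jax.numpy as jnp

def check_nan(tensor: jax.Array, name: str, step: jax.Array) -> None:
    # Static metadata goes into the f-string; traced flags go through placeholders.
    jax.debug.print(
        f"[{name}] shape={tensor.shape} dtype={tensor.dtype} "
        "step={step} has_nans={has_nans} has_infs={has_infs}",
        step=step,
        has_nans=jnp.isnan(tensor).any(),
        has_infs=jnp.isinf(tensor).any(),
    )

x = jnp.array([1.0, jnp.nan])
jax.jit(lambda t: check_nan(t, "x", jnp.int32(0)))(x)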
@@ -297,18 +285,14 @@ def convert_model_output(
         state: UniPCMultistepSchedulerState,
         model_output: jnp.ndarray,
         sample: jnp.ndarray,
-        step: jax.Array,
     ) -> jnp.ndarray:
         """
         Converts the model output based on the prediction type and current state.
         """
         sigma = state.sigmas[state.step_index]  # Current sigma
-        check_nan_jit(sigma, "convert_model_output sigma", step)

         # Ensure sigma is a JAX array for _sigma_to_alpha_sigma_t
         alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
-        check_nan_jit(alpha_t, "convert_model_output alpha_t", step)
-        check_nan_jit(sigma_t, "convert_model_output sigma_t", step)

         if self.config.predict_x0:
             if self.config.prediction_type == "epsilon":
@@ -326,7 +310,6 @@ def convert_model_output(
                     f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
                     "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
                 )
-        check_nan_jit(x0_pred, "convert_model_output x0_pred", step)

         if self.config.thresholding:
             raise NotImplementedError("Dynamic thresholding isn't implemented.")
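The hunk elides the per-prediction-type bodies. As a hedged illustration of what the `epsilon` branch computes under the usual VP parameterization (a sketch, not verbatim from this file), the forward relation sample = alpha_t * x0 + sigma_t * eps is inverted to recover the x0 estimate; `epsilon_to_x0` is an illustrative name:

import jax.numpy as jnp

def epsilon_to_x0(sample, model_output, alpha_t, sigma_t):
    # Invert sample = alpha_t * x0 + sigma_t * eps for x0, with the network
    # output interpreted as the noise estimate eps.
    return (sample - sigma_t * model_output) / alpha_t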
@@ -353,7 +336,6 @@ def multistep_uni_p_bh_update(
         model_output: jnp.ndarray,
         sample: jnp.ndarray,
         order: int,
-        step: jax.Array,
     ) -> jnp.ndarray:
         """
         One step for the UniP (B(h) version) - the Predictor.
@@ -362,52 +344,33 @@ def multistep_uni_p_bh_update(
             raise NotImplementedError("Nested `solver_p` is not implemented in JAX version yet.")

         m0 = state.model_outputs[self.config.solver_order - 1]  # Most recent stored converted model output
-        check_nan_jit(m0, "P m0", step)
         x = sample
-        check_nan_jit(x, "P sample", step)

         sigma_t_val, sigma_s0_val = (
             state.sigmas[state.step_index + 1],
             state.sigmas[state.step_index],
         )
-        check_nan_jit(sigma_t_val, "P sigma_t_val", step)
-        check_nan_jit(sigma_s0_val, "P sigma_s0_val", step)
-

         alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val)
-        check_nan_jit(alpha_t, "P alpha_t", step)
-        check_nan_jit(sigma_t, "P sigma_t", step)
-
         alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val)
-        check_nan_jit(alpha_s0, "P alpha_s0", step)
-        check_nan_jit(sigma_s0, "P sigma_s0", step)

         lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10)
-        check_nan_jit(lambda_t, "P lambda_t", step)
         lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10)
-        check_nan_jit(lambda_s0, "P lambda_s0", step)

         h = lambda_t - lambda_s0
-        check_nan_jit(h, "P h", step)

         def rk_d1_loop_body(i, carry):
             # Loop from i = 0 to order-2
             rks, D1s = carry
             history_idx = self.config.solver_order - 2 - i
             mi = state.model_outputs[history_idx]
-            check_nan_jit(mi, f"P rk_d1 mi[{i}]", step)
             si_val = state.timestep_list[history_idx]

             alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(state.sigmas[self.index_for_timestep(state, si_val)])
-            check_nan_jit(alpha_si, f"P rk_d1 alpha_si[{i}]", step)
-            check_nan_jit(sigma_si, f"P rk_d1 sigma_si[{i}]", step)
             lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10)
-            check_nan_jit(lambda_si, f"P rk_d1 lambda_si[{i}]", step)

             rk = (lambda_si - lambda_s0) / h
-            check_nan_jit(rk, f"P rk_d1 rk[{i}]", step)
             Di = (mi - m0) / rk
-            check_nan_jit(Di, f"P rk_d1 Di[{i}]", step)

             rks = rks.at[i].set(rk)
             D1s = D1s.at[i].set(Di)
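The history loop above uses the standard `jax.lax.fori_loop` idiom: pre-allocated arrays travel in the carry and each iteration fills one slot with a functional `.at[i].set(...)` update. A self-contained sketch of that idiom with a toy body (not the scheduler's math):

import jax
import jax.numpy as jnp

def body(i, carry):
    # Each iteration writes one slot of each carried array (functionally).
    rks, d1s = carry
    rks = rks.at[i].set(0.5 * (i + 1))
    d1s = d1s.at[i].set(1.0 / (i + 1))
    return rks, d1s

rks0 = jnp.zeros(3)
d1s0 = jnp.zeros(3)
rks, d1s = jax.lax.fori_loop(0, 2, body, (rks0, d1s0))  # fills slots 0..1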
@@ -419,37 +382,27 @@ def rk_d1_loop_body(i, carry):
         # Dummy D1s array. It will not be used if order == 1
         D1s_init = jnp.zeros((1, *m0.shape), dtype=m0.dtype)
         rks, D1s = jax.lax.fori_loop(0, order - 1, rk_d1_loop_body, (rks_init, D1s_init))
-        check_nan_jit(rks, "P rks after loop", step)
-        check_nan_jit(D1s, "P D1s after loop", step)
         rks = rks.at[order - 1].set(1.0)
-        check_nan_jit(rks, "P rks final", step)

         hh = -h if self.config.predict_x0 else h
-        check_nan_jit(hh, "P hh", step)
         h_phi_1 = jnp.expm1(hh)
-        check_nan_jit(h_phi_1, "P h_phi_1", step)

         if self.config.solver_type == "bh1":
             B_h = hh
         elif self.config.solver_type == "bh2":
             B_h = jnp.expm1(hh)
         else:
             raise NotImplementedError()
-        check_nan_jit(B_h, "P B_h", step)

         def rb_loop_body(i, carry):
             R, b, current_h_phi_k, factorial_val = carry
-            check_nan_jit(current_h_phi_k, f"P rb_loop[{i}] current_h_phi_k IN", step)
-            check_nan_jit(factorial_val, f"P rb_loop[{i}] factorial_val IN", step)
             R = R.at[i].set(jnp.power(rks, i))
             b = b.at[i].set(current_h_phi_k * factorial_val / B_h)

             def update_fn(vals):
                 _h_phi_k, _fac = vals
                 next_fac = _fac * (i + 2)
-                check_nan_jit(next_fac, f"P rb_loop[{i}] next_fac", step)
                 next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac
-                check_nan_jit(next_h_phi_k, f"P rb_loop[{i}] next_h_phi_k", step)
                 return next_h_phi_k, next_fac

             current_h_phi_k, factorial_val = jax.lax.cond(
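The call this hunk truncates gates the phi/factorial recursion with `jax.lax.cond`, which traces both branches and selects one at run time (a Python `if` is not possible here because `i` is traced inside `fori_loop`). A minimal sketch of that control-flow pattern with toy values (`step_phi` and its arguments are illustrative):

import jax
import jax.numpy as jnp

def step_phi(i, h_phi_k, fac, hh):
    # Advance (h_phi_k, fac) only while more iterations remain; else pass through.
    def update(vals):
        phi, f = vals
        f_next = f * (i + 2)
        return phi / hh - 1.0 / f_next, f_next

    return jax.lax.cond(i < 2, update, lambda vals: vals, (h_phi_k, fac))

phi, fac = step_phi(0, jnp.float32(0.1), jnp.float32(1.0), jnp.float32(-0.5))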
@@ -463,16 +416,11 @@ def update_fn(vals):
         R_init = jnp.zeros((self.config.solver_order, self.config.solver_order), dtype=h.dtype)
         b_init = jnp.zeros(self.config.solver_order, dtype=h.dtype)
         init_h_phi_k = h_phi_1 / hh - 1.0
-        check_nan_jit(init_h_phi_k, "P init_h_phi_k", step)
         init_factorial = 1.0
         R, b, _, _ = jax.lax.fori_loop(0, order, rb_loop_body, (R_init, b_init, init_h_phi_k, init_factorial))
-        check_nan_jit(R, "P R after loop", step)
-        check_nan_jit(b, "P b after loop", step)
-

         if len(D1s) > 0:
             D1s = jnp.stack(D1s, axis=1)  # Resulting shape (B, K, C, H, W)
-            check_nan_jit(D1s, "P D1s_stacked", step)

         def solve_for_rhos_p(R_mat, b_vec, current_order):
             # Create a mask for the top-left (current_order - 1) x (current_order - 1) sub-matrix
@@ -487,12 +435,9 @@ def solve_for_rhos_p(R_mat, b_vec, current_order):
                 jnp.eye(mask_size, dtype=R_mat.dtype),
             )
             b_safe = jnp.where(mask, b_vec[:mask_size], 0.0)
-            check_nan_jit(R_safe, "P solve R_safe", step)
-            check_nan_jit(b_safe, "P solve b_safe", step)

             # Solve the system and mask the result
             solved_rhos = jnp.linalg.solve(R_safe, b_safe)
-            check_nan_jit(solved_rhos, "P solve solved_rhos", step)
             return jnp.where(mask, solved_rhos, 0.0)

         # Handle the special case for order == 2
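The masked solve keeps the linear system well posed under jit: `order` is traced, so the matrix shape must stay static, and entries outside the active block are replaced by the identity so `jnp.linalg.solve` always factors an invertible matrix; the padded solution entries are zeroed afterwards. A simplified self-contained sketch of the trick (`masked_solve` and its mask construction are illustrative):

import jax.numpy as jnp

def masked_solve(R_mat, b_vec, active):
    # Solve only the leading `active` x `active` block of a fixed-size system.
    n = R_mat.shape[0]
    mask = jnp.arange(n) < active
    mask2d = mask[:, None] & mask[None, :]
    # Identity outside the active block keeps the full matrix invertible.
    R_safe = jnp.where(mask2d, R_mat, jnp.eye(n, dtype=R_mat.dtype))
    b_safe = jnp.where(mask, b_vec, 0.0)
    return jnp.where(mask, jnp.linalg.solve(R_safe, b_safe), 0.0)

rhos = masked_solve(jnp.eye(3) * 2.0, jnp.ones(3), active=2)  # [0.5, 0.5, 0.0]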
@@ -504,11 +449,9 @@ def solve_for_rhos_p(R_mat, b_vec, current_order):

         # Get the result for the general case
         rhos_p_general = solve_for_rhos_p(R, b, order)
-        check_nan_jit(rhos_p_general, "P rhos_p_general", step)

         # Select the appropriate result based on the order
         rhos_p = jnp.where(order == 2, rhos_p_order2, rhos_p_general)
-        check_nan_jit(rhos_p, "P rhos_p", step)

         pred_res = jax.lax.cond(
             order > 1,
@@ -517,21 +460,14 @@ def solve_for_rhos_p(R_mat, b_vec, current_order):
             lambda _: jnp.zeros_like(x),
             operand=None,
         )
-        check_nan_jit(pred_res, "P pred_res", step)

         if self.config.predict_x0:
-            x_t_ = sigma_t / (sigma_s0) * x - alpha_t * h_phi_1 * m0
-            check_nan_jit(x_t_, "P x_t_ term", step)
-            term2 = alpha_t * B_h * pred_res
-            check_nan_jit(term2, "P term2", step)
-            x_t = x_t_ - term2
+            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
+            x_t = x_t_ - alpha_t * B_h * pred_res
         else:  # Predict epsilon
-            x_t_ = alpha_t / (alpha_s0) * x - sigma_t * h_phi_1 * m0
-            check_nan_jit(x_t_, "P x_t_ term eps", step)
-            term2 = sigma_t * B_h * pred_res
-            check_nan_jit(term2, "P term2 eps", step)
-            x_t = x_t_ - term2
-        check_nan_jit(x_t, "P final x_t", step)
+            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
+            x_t = x_t_ - sigma_t * B_h * pred_res
+
         return x_t.astype(x.dtype)

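When `order == 1` there is no history, `pred_res` is all zeros (the `jax.lax.cond` above takes the false branch), and the `predict_x0` update reduces to the first-order data-prediction step. A self-contained numeric sketch of that reduced update, assuming the non-flow VP relation alpha = 1/sqrt(sigma^2 + 1) used by `_sigma_to_alpha_sigma_t` (the name `first_order_update` is illustrative):

import jax.numpy as jnp

def first_order_update(x, m0, sigma_s0, sigma_t):
    # Reduced UniP step for order == 1 (pred_res == 0), predict_x0 branch:
    # x_t = sigma_t / sigma_s0 * x - alpha_t * (exp(-h) - 1) * m0
    def to_alpha_sigma(s):
        alpha = 1.0 / jnp.sqrt(s**2 + 1.0)
        return alpha, s * alpha

    alpha_t, sig_t = to_alpha_sigma(sigma_t)
    alpha_s0, sig_s0 = to_alpha_sigma(sigma_s0)
    h = (jnp.log(alpha_t) - jnp.log(sig_t)) - (jnp.log(alpha_s0) - jnp.log(sig_s0))
    h_phi_1 = jnp.expm1(-h)  # hh = -h when predicting x0
    return sig_t / sig_s0 * x - alpha_t * h_phi_1 * m0

x_t = first_order_update(jnp.ones(4), jnp.zeros(4), jnp.float32(1.0), jnp.float32(0.5))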
     def multistep_uni_c_bh_update(
@@ -541,7 +477,6 @@ def multistep_uni_c_bh_update(
         last_sample: jnp.ndarray,  # Sample after predictor `x_{t-1}`
         this_sample: jnp.ndarray,  # Sample before corrector `x_t` (after predictor step)
         order: int,
-        step: jax.Array,
     ) -> jnp.ndarray:
         """
         One step for the UniC (B(h) version) - the Corrector.
@@ -685,8 +620,7 @@ def solve_for_rhos(R_mat, b_vec, current_order):
         else:
             x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
             x_t = x_t_ - sigma_t * B_h * (corr_res + final_rho * D1_t)
-
-        check_nan_jit(x_t, "corrector x_t", step)
+
         return x_t.astype(x.dtype)

     def index_for_timestep(
@@ -740,10 +674,6 @@ def step(
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
         the multistep UniPC.
         """
-        step_val = state.step_index  # For debug, might be None initially
-
-        check_nan_jit(model_output, "step input model_output", step_val)
-        check_nan_jit(sample, "step input sample", step_val)

         sample = sample.astype(jnp.float32)

@@ -755,7 +685,6 @@ def step(
         # Initialize step_index if it's the first step
         if state.step_index is None:
             state = self._init_step_index(state, timestep_scalar)
-            step_val = state.step_index

         # Determine if corrector should be used
         use_corrector = (
         )

         # Convert model_output (noise/v_pred) to x0_pred or epsilon_pred, based on prediction_type
-        model_output_for_history = self.convert_model_output(state, model_output, sample, step_val)
-        check_nan_jit(model_output_for_history, "model_output_for_history", step_val)
+        model_output_for_history = self.convert_model_output(state, model_output, sample)

         # Apply corrector if applicable
         sample = jax.lax.cond(
@@ -777,11 +705,9 @@ def step(
                 last_sample=state.last_sample,
                 this_sample=sample,
                 order=state.this_order,
-                step=step_val
             ),
             lambda: sample,
         )
-        check_nan_jit(sample, "sample_corrected", step_val)

         # Update history buffers (model_outputs and timestep_list)
         # Shift existing elements to the left and add new one at the end.
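The history update the comment describes is a functional roll-and-append; a minimal sketch of that pattern (illustrative, not the file's exact code):

import jax.numpy as jnp

def push_history(buf, new):
    # Drop the oldest entry (index 0), shift the rest left, append the newest.
    buf = buf.at[:-1].set(buf[1:])
    return buf.at[-1].set(new)

hist = jnp.array([1.0, 2.0, 3.0])
hist = push_history(hist, jnp.float32(4.0))  # [2.0, 3.0, 4.0]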
@@ -832,7 +758,6 @@ def non_step_idx0_branch():
                 model_output=model_output,
                 sample=sample,
                 order=state.this_order,
-                step=step_val,
             )

         # Update lower_order_nums for warmup
@@ -869,16 +794,14 @@ def add_noise(
     def _sigma_to_alpha_sigma_t(self, sigma):
         eps = 1e-10
         if self.config.use_flow_sigmas:
-            alpha_t = 1 - sigma
-            sigma_t = sigma
+            alpha_t = jnp.maximum(1 - sigma, eps)
+            sigma_t = jnp.maximum(sigma, eps)
         else:
-            sigma_clamped = jnp.maximum(sigma, eps)
-            alpha_t = 1 / ((sigma_clamped ** 2 + 1) ** 0.5)
-            sigma_t = sigma_clamped * alpha_t
-            alpha_t = jnp.maximum(alpha_t, eps)
-            sigma_t = jnp.maximum(sigma_t, eps)
+            sigma_safe = jnp.maximum(sigma, eps)
+            alpha_t = 1 / ((sigma_safe ** 2 + 1) ** 0.5)
+            sigma_t = sigma_safe * alpha_t

         return alpha_t, sigma_t

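These clamps matter because the predictor immediately takes `jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10)` and divides by `sigma_s0`: at flow-schedule endpoints (sigma of 0 or 1) an unclamped factor would drive the log-SNR toward infinity or produce a division by zero. A quick standalone check of the new flow branch (helper name is illustrative):

import jax.numpy as jnp

def flow_alpha_sigma(sigma, eps=1e-10):
    # Mirrors the use_flow_sigmas branch after this diff: both factors are
    # clamped away from zero so downstream logs and divisions stay finite.
    return jnp.maximum(1 - sigma, eps), jnp.maximum(sigma, eps)

alpha_t, sigma_t = flow_alpha_sigma(jnp.float32(0.0))
log_snr_half = jnp.log(alpha_t) - jnp.log(sigma_t)  # finite, not inf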
     def __len__(self) -> int:
-        return self.config.num_train_timesteps
+        return self.config.num_train_timesteps