Skip to content

Commit d05ca47

Browse files
committed
trying fix for NaN
1 parent 3140d45 commit d05ca47

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -518,13 +518,13 @@ def solve_for_rhos_p(R_mat, b_vec, current_order):
518518
check_nan_jit(pred_res, "P pred_res", step)
519519

520520
if self.config.predict_x0:
521-
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
521+
x_t_ = sigma_t / (sigma_s0 + 1e-8) * x - alpha_t * h_phi_1 * m0
522522
check_nan_jit(x_t_, "P x_t_ term", step)
523523
term2 = alpha_t * B_h * pred_res
524524
check_nan_jit(term2, "P term2", step)
525525
x_t = x_t_ - term2
526526
else: # Predict epsilon
527-
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
527+
x_t_ = alpha_t / (alpha_s0 + 1e-8) * x - sigma_t * h_phi_1 * m0
528528
check_nan_jit(x_t_, "P x_t_ term eps", step)
529529
term2 = sigma_t * B_h * pred_res
530530
check_nan_jit(term2, "P term2 eps", step)

0 commit comments

Comments
 (0)