Skip to content

Commit 3140d45

Browse files
committed
debug for NaNs
1 parent 9474c15 commit 3140d45

1 file changed

Lines changed: 49 additions & 5 deletions

File tree

src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -360,35 +360,52 @@ def multistep_uni_p_bh_update(
360360
raise NotImplementedError("Nested `solver_p` is not implemented in JAX version yet.")
361361

362362
m0 = state.model_outputs[self.config.solver_order - 1] # Most recent stored converted model output
363+
check_nan_jit(m0, "P m0", step)
363364
x = sample
365+
check_nan_jit(x, "P sample", step)
364366

365367
sigma_t_val, sigma_s0_val = (
366368
state.sigmas[state.step_index + 1],
367369
state.sigmas[state.step_index],
368370
)
371+
check_nan_jit(sigma_t_val, "P sigma_t_val", step)
372+
check_nan_jit(sigma_s0_val, "P sigma_s0_val", step)
373+
369374

370375
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t_val)
376+
check_nan_jit(alpha_t, "P alpha_t", step)
377+
check_nan_jit(sigma_t, "P sigma_t", step)
378+
371379
alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0_val)
380+
check_nan_jit(alpha_s0, "P alpha_s0", step)
381+
check_nan_jit(sigma_s0, "P sigma_s0", step)
372382

373383
lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10)
384+
check_nan_jit(lambda_t, "P lambda_t", step)
374385
lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10)
386+
check_nan_jit(lambda_s0, "P lambda_s0", step)
375387

376388
h = lambda_t - lambda_s0
377-
check_nan_jit(h, "predictor h", step)
389+
check_nan_jit(h, "P h", step)
378390

379391
def rk_d1_loop_body(i, carry):
380392
# Loop from i = 0 to order-2
381393
rks, D1s = carry
382394
history_idx = self.config.solver_order - 2 - i
383395
mi = state.model_outputs[history_idx]
396+
check_nan_jit(mi, f"P rk_d1 mi[{i}]", step)
384397
si_val = state.timestep_list[history_idx]
385398

386399
alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(state.sigmas[self.index_for_timestep(state, si_val)])
400+
check_nan_jit(alpha_si, f"P rk_d1 alpha_si[{i}]", step)
401+
check_nan_jit(sigma_si, f"P rk_d1 sigma_si[{i}]", step)
387402
lambda_si = jnp.log(alpha_si + 1e-10) - jnp.log(sigma_si + 1e-10)
403+
check_nan_jit(lambda_si, f"P rk_d1 lambda_si[{i}]", step)
388404

389405
rk = (lambda_si - lambda_s0) / h
406+
check_nan_jit(rk, f"P rk_d1 rk[{i}]", step)
390407
Di = (mi - m0) / rk
391-
check_nan_jit(Di, f"predictor Di[{i}]", step)
408+
check_nan_jit(Di, f"P rk_d1 Di[{i}]", step)
392409

393410
rks = rks.at[i].set(rk)
394411
D1s = D1s.at[i].set(Di)
@@ -400,27 +417,37 @@ def rk_d1_loop_body(i, carry):
400417
# Dummy D1s array. It will not be used if order == 1
401418
D1s_init = jnp.zeros((1, *m0.shape), dtype=m0.dtype)
402419
rks, D1s = jax.lax.fori_loop(0, order - 1, rk_d1_loop_body, (rks_init, D1s_init))
420+
check_nan_jit(rks, "P rks after loop", step)
421+
check_nan_jit(D1s, "P D1s after loop", step)
403422
rks = rks.at[order - 1].set(1.0)
423+
check_nan_jit(rks, "P rks final", step)
404424

405425
hh = -h if self.config.predict_x0 else h
426+
check_nan_jit(hh, "P hh", step)
406427
h_phi_1 = jnp.expm1(hh)
428+
check_nan_jit(h_phi_1, "P h_phi_1", step)
407429

408430
if self.config.solver_type == "bh1":
409431
B_h = hh
410432
elif self.config.solver_type == "bh2":
411433
B_h = jnp.expm1(hh)
412434
else:
413435
raise NotImplementedError()
436+
check_nan_jit(B_h, "P B_h", step)
414437

415438
def rb_loop_body(i, carry):
416439
R, b, current_h_phi_k, factorial_val = carry
440+
check_nan_jit(current_h_phi_k, f"P rb_loop[{i}] current_h_phi_k IN", step)
441+
check_nan_jit(factorial_val, f"P rb_loop[{i}] factorial_val IN", step)
417442
R = R.at[i].set(jnp.power(rks, i))
418443
b = b.at[i].set(current_h_phi_k * factorial_val / B_h)
419444

420445
def update_fn(vals):
421446
_h_phi_k, _fac = vals
422447
next_fac = _fac * (i + 2)
448+
check_nan_jit(next_fac, f"P rb_loop[{i}] next_fac", step)
423449
next_h_phi_k = _h_phi_k / hh - 1.0 / next_fac
450+
check_nan_jit(next_h_phi_k, f"P rb_loop[{i}] next_h_phi_k", step)
424451
return next_h_phi_k, next_fac
425452

426453
current_h_phi_k, factorial_val = jax.lax.cond(
@@ -434,11 +461,16 @@ def update_fn(vals):
434461
R_init = jnp.zeros((self.config.solver_order, self.config.solver_order), dtype=h.dtype)
435462
b_init = jnp.zeros(self.config.solver_order, dtype=h.dtype)
436463
init_h_phi_k = h_phi_1 / hh - 1.0
464+
check_nan_jit(init_h_phi_k, "P init_h_phi_k", step)
437465
init_factorial = 1.0
438466
R, b, _, _ = jax.lax.fori_loop(0, order, rb_loop_body, (R_init, b_init, init_h_phi_k, init_factorial))
467+
check_nan_jit(R, "P R after loop", step)
468+
check_nan_jit(b, "P b after loop", step)
469+
439470

440471
if len(D1s) > 0:
441472
D1s = jnp.stack(D1s, axis=1) # Resulting shape (B, K, C, H, W)
473+
check_nan_jit(D1s, "P D1s_stacked", step)
442474

443475
def solve_for_rhos_p(R_mat, b_vec, current_order):
444476
# Create a mask for the top-left (current_order - 1) x (current_order - 1) sub-matrix
@@ -453,9 +485,12 @@ def solve_for_rhos_p(R_mat, b_vec, current_order):
453485
jnp.eye(mask_size, dtype=R_mat.dtype),
454486
)
455487
b_safe = jnp.where(mask, b_vec[:mask_size], 0.0)
488+
check_nan_jit(R_safe, "P solve R_safe", step)
489+
check_nan_jit(b_safe, "P solve b_safe", step)
456490

457491
# Solve the system and mask the result
458492
solved_rhos = jnp.linalg.solve(R_safe, b_safe)
493+
check_nan_jit(solved_rhos, "P solve solved_rhos", step)
459494
return jnp.where(mask, solved_rhos, 0.0)
460495

461496
# Handle the special case for order == 2
@@ -467,9 +502,11 @@ def solve_for_rhos_p(R_mat, b_vec, current_order):
467502

468503
# Get the result for the general case
469504
rhos_p_general = solve_for_rhos_p(R, b, order)
505+
check_nan_jit(rhos_p_general, "P rhos_p_general", step)
470506

471507
# Select the appropriate result based on the order
472508
rhos_p = jnp.where(order == 2, rhos_p_order2, rhos_p_general)
509+
check_nan_jit(rhos_p, "P rhos_p", step)
473510

474511
pred_res = jax.lax.cond(
475512
order > 1,
@@ -478,14 +515,21 @@ def solve_for_rhos_p(R_mat, b_vec, current_order):
478515
lambda _: jnp.zeros_like(x),
479516
operand=None,
480517
)
518+
check_nan_jit(pred_res, "P pred_res", step)
481519

482520
if self.config.predict_x0:
483521
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
484-
x_t = x_t_ - alpha_t * B_h * pred_res
522+
check_nan_jit(x_t_, "P x_t_ term", step)
523+
term2 = alpha_t * B_h * pred_res
524+
check_nan_jit(term2, "P term2", step)
525+
x_t = x_t_ - term2
485526
else: # Predict epsilon
486527
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
487-
x_t = x_t_ - sigma_t * B_h * pred_res
488-
check_nan_jit(x_t, "predictor x_t", step)
528+
check_nan_jit(x_t_, "P x_t_ term eps", step)
529+
term2 = sigma_t * B_h * pred_res
530+
check_nan_jit(term2, "P term2 eps", step)
531+
x_t = x_t_ - term2
532+
check_nan_jit(x_t, "P final x_t", step)
489533
return x_t.astype(x.dtype)
490534

491535
def multistep_uni_c_bh_update(

0 commit comments

Comments
 (0)