Skip to content

Commit b23a2c1

Browse files
committed
epsilon clamping
1 parent 6c13bb9 commit b23a2c1

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def multistep_uni_p_bh_update(
387387
lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10)
388388
check_nan_jit(lambda_s0, "P lambda_s0", step)
389389

390-
h = lambda_t - lambda_s0
390+
h = jnp.clip(lambda_t - lambda_s0, -20.0, 20.0)
391391
check_nan_jit(h, "P h", step)
392392

393393
def rk_d1_loop_body(i, carry):
@@ -867,12 +867,16 @@ def add_noise(
867867
return add_noise_common(state.common, original_samples, noise, timesteps)
868868

869869
def _sigma_to_alpha_sigma_t(self, sigma):
870+
eps = 1e-10
870871
if self.config.use_flow_sigmas:
871872
alpha_t = 1 - sigma
872873
sigma_t = sigma
873874
else:
874-
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
875-
sigma_t = sigma * alpha_t
875+
sigma_clamped = jnp.maximum(sigma, eps)
876+
alpha_t = 1 / ((sigma_clamped**2 + 1) ** 0.5)
877+
sigma_t = sigma_clamped * alpha_t
878+
alpha_t = jnp.maximum(alpha_t, eps)
879+
sigma_t = jnp.maximum(sigma_t, eps)
876880

877881
return alpha_t, sigma_t
878882

0 commit comments

Comments
 (0)