Skip to content

Commit 6c13bb9

Browse files
committed
some changes reverted
1 parent e160f84 commit 6c13bb9

3 files changed

Lines changed: 5 additions & 6 deletions

File tree

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,10 @@ timestep_bias: {
113113

114114
# Override parameters from checkpoints's scheduler.
115115
diffusion_scheduler_config: {
116-
_class_name: 'FlaxUniPCMultistepScheduler',
116+
_class_name: 'FlaxEulerDiscreteScheduler',
117117
prediction_type: 'epsilon',
118118
rescale_zero_terminal_snr: False,
119-
timestep_spacing: 'trailing',
120-
final_sigmas_type: 'sigma_min'
119+
timestep_spacing: 'trailing'
121120
}
122121

123122
# Output directory

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def __call__(
189189
height=height,
190190
width=width,
191191
num_frames=num_frames,
192-
dtype=jnp.float32,
192+
dtype=image_embeds.dtype,
193193
rng=latents_rng,
194194
latents=latents,
195195
last_image=last_image_tensor,

src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,9 @@ def multistep_uni_p_bh_update(
382382
check_nan_jit(alpha_s0, "P alpha_s0", step)
383383
check_nan_jit(sigma_s0, "P sigma_s0", step)
384384

385-
lambda_t = jnp.log(alpha_t + 1e-5) - jnp.log(sigma_t + 1e-5)
385+
lambda_t = jnp.log(alpha_t + 1e-10) - jnp.log(sigma_t + 1e-10)
386386
check_nan_jit(lambda_t, "P lambda_t", step)
387-
lambda_s0 = jnp.log(alpha_s0 + 1e-5) - jnp.log(sigma_s0 + 1e-5)
387+
lambda_s0 = jnp.log(alpha_s0 + 1e-10) - jnp.log(sigma_s0 + 1e-10)
388388
check_nan_jit(lambda_s0, "P lambda_s0", step)
389389

390390
h = lambda_t - lambda_s0

0 commit comments

Comments
 (0)