Skip to content

Commit 390e113

Browse files
committed
changed expand_timesteps logic
1 parent bc6955d commit 390e113

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,8 @@ def __call__(
255255
)
256256
if self.config.expand_timesteps:
257257
jax.debug.print("Applying first frame preservation with expand_timesteps.")
258-
latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
258+
clean_latents = condition[..., 4:]
259+
latents = first_frame_mask * clean_latents + (1 - first_frame_mask) * latents
259260
latents_bcthw = jnp.transpose(latents, (0, 4, 1, 2, 3))
260261
latents_denorm_bcthw = self._denormalize_latents(latents_bcthw)
261262

0 commit comments

Comments
 (0)