Skip to content

Commit 306ef81

Browse files
committed
added debug
1 parent a72339d commit 306ef81

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,6 @@ def prepare_latents(
107107

108108
num_channels_latents = self.vae.z_dim
109109
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
110-
jax.debug.print("num_frames: {nf}, num_latent_frames: {nlf}, expected: {exp}",
111-
nf=num_frames,
112-
nlf=latents.shape[1],
113-
exp=num_latent_frames)
114110
latent_height = height // self.vae_scale_factor_spatial
115111
latent_width = width // self.vae_scale_factor_spatial
116112

@@ -120,6 +116,10 @@ def prepare_latents(
120116
latents = randn_tensor(shape, rng, self.config, dtype)
121117
else:
122118
latents = latents.astype(dtype)
119+
jax.debug.print("num_frames: {nf}, num_latent_frames: {nlf}, expected: {exp}",
120+
nf=num_frames,
121+
nlf=latents.shape[1],
122+
exp=num_latent_frames)
123123
latent_condition, _ = self.prepare_latents_i2v_base(image, num_frames, dtype, last_image)
124124
mask_lat_size = jnp.ones((batch_size, 1, num_frames, latent_height, latent_width), dtype=dtype)
125125
if last_image is None:

0 commit comments

Comments
 (0)