Skip to content

Commit 32a7035

Browse files
committed
Added more debug for NaNs
1 parent 836e94e commit 32a7035

1 file changed

Lines changed: 20 additions & 2 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def prepare_latents(
154154
lm=jnp.mean(condition[..., 1:]))
155155
jax.debug.print("condition latent std={std}", std=jnp.std(condition[..., 1:]))
156156

157-
return latents, condition, None
157+
return latents, condition, first_frame_mask
158158

159159

160160
def __call__(
@@ -213,6 +213,14 @@ def __call__(
213213
last_image=last_image_tensor,
214214
num_videos_per_prompt=num_videos_per_prompt,
215215
)
216+
if first_frame_mask is not None:
217+
jax.debug.print("FIRST FRAME MASK stats: min={mn}, max={mx}, mean={mean}, shape={shape}",
218+
mn=jnp.min(first_frame_mask),
219+
mx=jnp.max(first_frame_mask),
220+
mean=jnp.mean(first_frame_mask),
221+
shape=first_frame_mask.shape)
222+
else:
223+
jax.debug.print("FIRST FRAME MASK is None")
216224

217225
scheduler_state = self.scheduler.set_timesteps(
218226
self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
@@ -341,6 +349,13 @@ def loop_body(step, vals):
341349
encoder_hidden_states_image=image_embeds_input,
342350
)
343351
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
352+
if step == 0:
353+
jax.debug.print("STEP 0: latents std={ls}, noise_pred std ={ns}, latents mean={lm}, noise mean={nm}",
354+
ls=jnp.std(latents),
355+
ns=jnp.std(noise_pred),
356+
lm=jnp.mean(latents),
357+
nm=jnp.mean(noise_pred))
358+
344359
jax.debug.print("Step {s}: noise_pred stats min={mn}, max={mx}, mean={mean}, std={std}",
345360
s=step,
346361
mn=jnp.min(noise_pred),
@@ -351,11 +366,14 @@ def loop_body(step, vals):
351366
s=step,
352367
std=jnp.std(latents),
353368
mean=jnp.mean(latents))
354-
jax.debug.print("first_frame_mask shape:", first_frame_mask.shape)
369+
jax.debug.print("first_frame_mask shape:", first_frame_mask.shape if first_frame_mask is not None else (-1,))
355370
jax.debug.print("first_frame_mask unique values:", jnp.unique(first_frame_mask))
356371
jax.debug.print("condition shape:", condition.shape)
357372
jax.debug.print("condition stats:", jnp.min(condition), jnp.max(condition), jnp.mean(condition))
358373
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
374+
if first_frame_mask is not None:
375+
clean_latents = condition[..., 4:]
376+
latents = first_frame_mask * clean_latents + (1.0 - first_frame_mask) * latents
359377
jax.debug.print("Step {s}: latents_next std={std}, mean={mean}",
360378
s=step,
361379
std=jnp.std(latents),

0 commit comments

Comments
 (0)