Skip to content

Commit 46fd888

Browse files
committed
Added more debug for NaNs
1 parent 32a7035 commit 46fd888

1 file changed

Lines changed: 13 additions & 7 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -349,13 +349,19 @@ def loop_body(step, vals):
349349
encoder_hidden_states_image=image_embeds_input,
350350
)
351351
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
352-
if step == 0:
353-
jax.debug.print("STEP 0: latents std={ls}, noise_pred std ={ns}, latents mean={lm}, noise mean={nm}",
354-
ls=jnp.std(latents),
355-
ns=jnp.std(noise_pred),
356-
lm=jnp.mean(latents),
357-
nm=jnp.mean(noise_pred))
358-
352+
def print_step_zero_stats(operands):
353+
l, np_pred = operands
354+
jax.debug.print("STEP 0: latents std={ls}, noise_pred std ={ns}, latents mean={lm}, noise mean={nm}",
355+
ls=jnp.std(l),
356+
ns=jnp.std(np_pred),
357+
lm=jnp.mean(l),
358+
nm=jnp.mean(np_pred))
359+
jax.lax.cond(
360+
step == 0,
361+
print_step_zero_stats,
362+
lambda _: None, # Do nothing if step != 0
363+
(latents, noise_pred)
364+
)
359365
jax.debug.print("Step {s}: noise_pred stats min={mn}, max={mx}, mean={mean}, std={std}",
360366
s=step,
361367
mn=jnp.min(noise_pred),

0 commit comments

Comments
 (0)