Skip to content

Commit 288f8c5

Browse files
committed
Added more debug for NaNs
1 parent 46fd888 commit 288f8c5

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,10 @@ def print_step_zero_stats(operands):
372372
s=step,
373373
std=jnp.std(latents),
374374
mean=jnp.mean(latents))
375-
jax.debug.print("first_frame_mask shape:", first_frame_mask.shape if first_frame_mask is not None else (-1,))
376-
jax.debug.print("first_frame_mask unique values:", jnp.unique(first_frame_mask))
377-
jax.debug.print("condition shape:", condition.shape)
378-
jax.debug.print("condition stats:", jnp.min(condition), jnp.max(condition), jnp.mean(condition))
375+
jax.debug.print("first_frame_mask shape: {}", first_frame_mask.shape if first_frame_mask is not None else (-1,))
376+
jax.debug.print("first_frame_mask unique values: {}", jnp.unique(first_frame_mask))
377+
jax.debug.print("condition shape: {}", condition.shape)
378+
jax.debug.print("condition stats: {}, {}, {}", jnp.min(condition), jnp.max(condition), jnp.mean(condition))
379379
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
380380
if first_frame_mask is not None:
381381
clean_latents = condition[..., 4:]

0 commit comments

Comments
 (0)