Skip to content

Commit 68fca3a

Browse files
committed
added debug
1 parent 306ef81 commit 68fca3a

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def prepare_latents(
131131
jax.debug.print("first_frame_mask.shape:{shape}, is None:{isnone}",
132132
shape = first_frame_mask.shape if first_frame_mask is not None else (-1,),
133133
isnone = first_frame_mask is None)
134-
jax.debug.print("first_frame_mask_stats: min={mn:.2f}, max={mx:.2f}, mean={mean:.2f}",
134+
jax.debug.print("first_frame_mask_stats: min={mn}, max={mx}, mean={mean}",
135135
mn=jnp.min(first_frame_mask) if first_frame_mask is not None else 0.0,
136136
mx=jnp.max(first_frame_mask) if first_frame_mask is not None else 0.0,
137137
mean=jnp.mean(first_frame_mask) if first_frame_mask is not None else 0.0)
@@ -149,7 +149,7 @@ def prepare_latents(
149149
jax.debug.print("condition shape: {shape}, channel dim: {c}",
150150
shape=condition.shape,
151151
c=condition.shape[-1])
152-
jax.debug.print("condition stats: mask_mean={mm:.4f}, latent_mean={lm:.4f}",
152+
jax.debug.print("condition stats: mask_mean={mm}, latent_mean={lm}",
153153
mm=jnp.mean(condition[..., 0]),
154154
lm=jnp.mean(condition[..., 1:]))
155155

@@ -317,12 +317,12 @@ def loop_body(step, vals):
317317
encoder_hidden_states_image=image_embeds_input,
318318
)
319319
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
320-
jax.debug.print("Step {s}: latents_prev std={std:.6f}, mean={mean:.6f}",
320+
jax.debug.print("Step {s}: latents_prev std={std}, mean={mean}",
321321
s=step,
322322
std=jnp.std(latents),
323323
mean=jnp.mean(latents))
324324
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
325-
jax.debug.print("Step {s}: latents_next std={std:.6f}, mean={mean:.6f}",
325+
jax.debug.print("Step {s}: latents_next std={std}, mean={mean}",
326326
s=step,
327327
std=jnp.std(latents),
328328
mean=jnp.mean(latents))
@@ -331,7 +331,7 @@ def loop_body(step, vals):
331331

332332
max_logging.log(f"Running fori_loop for {num_inference_steps} steps.")
333333
latents, _, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state, rng))
334-
jax.debug.print("Final latents states: min={lmin:.6f}, max={lmax:.6f}, mean={lmean:.6f}, std={lstd:.6f}",
334+
jax.debug.print("Final latents states: min={lmin}, max={lmax}, mean={lmean}, std={lstd}",
335335
lmin=jnp.min(latents),
336336
lmax=jnp.max(latents),
337337
lmean=jnp.mean(latents),

0 commit comments

Comments
 (0)