Skip to content

Commit 836e94e

Browse files
committed
more debug added
1 parent b69d05f commit 836e94e

1 file changed

Lines changed: 27 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def prepare_latents(
152152
jax.debug.print("condition stats: mask_mean={mm}, latent_mean={lm}",
153153
mm=jnp.mean(condition[..., 0]),
154154
lm=jnp.mean(condition[..., 1:]))
155+
jax.debug.print("condition latent std={std}", std=jnp.std(condition[..., 1:]))
155156

156157
return latents, condition, None
157158

@@ -303,15 +304,31 @@ def loop_body(step, vals):
303304
original_dtype = latents.dtype
304305
rng, timestep_rng = jax.random.split(rng)
305306
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
307+
jax.debug.print("Step {s}: timestep={t}", s=step, t=t)
306308

307309
latents_input = latents
308310
if do_classifier_free_guidance:
309311
latents_input = jnp.concatenate([latents, latents], axis=0)
312+
jax.debug.print("Step{s}: latents_input stats min={mn}, max={mx}, mean={mean}, std={std}",
313+
s=step,
314+
mn=jnp.min(latents_input),
315+
mx=jnp.max(latents_input),
316+
mean=jnp.mean(latents_input),
317+
std=jnp.std(latents_input))
310318

311319
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
312320
timestep = jnp.broadcast_to(t, latents_input.shape[0])
313321
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
314322

323+
jax.debug.print("Step {s}: latent_model_input shape: {shape}",
324+
s=step,
325+
shape=latent_model_input.shape)
326+
327+
channel_energy = jnp.sum(latent_model_input*latent_model_input,axis=(0,2,3,4))
328+
jax.debug.print("Step {s}: channel energy first 10={ce}",
329+
s=step,
330+
ce=channel_energy[:10])
331+
315332
prompt_embeds_input = prompt_embeds
316333
image_embeds_input = image_embeds
317334

@@ -324,10 +341,20 @@ def loop_body(step, vals):
324341
encoder_hidden_states_image=image_embeds_input,
325342
)
326343
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
344+
jax.debug.print("Step {s}: noise_pred stats min={mn}, max={mx}, mean={mean}, std={std}",
345+
s=step,
346+
mn=jnp.min(noise_pred),
347+
mx=jnp.max(noise_pred),
348+
mean=jnp.mean(noise_pred),
349+
std=jnp.std(noise_pred))
327350
jax.debug.print("Step {s}: latents_prev std={std}, mean={mean}",
328351
s=step,
329352
std=jnp.std(latents),
330353
mean=jnp.mean(latents))
354+
jax.debug.print("first_frame_mask shape:", first_frame_mask.shape)
355+
jax.debug.print("first_frame_mask unique values:", jnp.unique(first_frame_mask))
356+
jax.debug.print("condition shape:", condition.shape)
357+
jax.debug.print("condition stats:", jnp.min(condition), jnp.max(condition), jnp.mean(condition))
331358
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
332359
jax.debug.print("Step {s}: latents_next std={std}, mean={mean}",
333360
s=step,

0 commit comments

Comments
 (0)