Skip to content

Commit 7572e09

Browse files
committed
debug added in wan 2.1 i2v pipeline
1 parent ed0655d commit 7572e09

1 file changed

Lines changed: 24 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,16 @@ def __call__(
256256
max_logging.log(f"[DEBUG CALL] Decoded video type: {type(decoded_video)}")
257257
return decoded_video
258258

259+
def check_nan_jit(tensor: jax.Array, name: str, step: jax.Array):
260+
if tensor is None:
261+
return
262+
263+
has_nans = jnp.isnan(tensor).any()
264+
has_infs = jnp.isinf(tensor).any()
265+
jax.debug.print(f"[DEBUG JIT {jax.process_index()}] Step: {{step}} - {name}: "
266+
"Shape: {shape}, Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}",
267+
step=step, shape=tensor.shape, has_nans_val=has_nans, has_infs_val=has_infs)
268+
259269
def run_inference_2_1_i2v(
260270
graphdef, sharded_state, rest_of_state,
261271
latents: jnp.array,
@@ -289,16 +299,24 @@ def loop_body(step, vals):
289299
rng, timestep_rng = jax.random.split(rng)
290300
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
291301

302+
check_nan_jit(latents, "latents_prev at loop start", step)
303+
292304
latents_input = latents
293305
if do_classifier_free_guidance:
294306
latents_input = jnp.concatenate([latents, latents], axis=0)
307+
check_nan_jit(latents_input, "latents_input after CFG concat", step)
295308

296309
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
310+
check_nan_jit(latent_model_input, "latent_model_input after cond concat", step)
297311
timestep = jnp.broadcast_to(t, latents_input.shape[0])
298312
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
313+
check_nan_jit(latent_model_input, "latent_model_input for transformer", step)
299314

300315
prompt_embeds_input = prompt_embeds
301316
image_embeds_input = image_embeds
317+
check_nan_jit(prompt_embeds_input, "prompt_embeds_input for transformer", step)
318+
check_nan_jit(image_embeds_input, "image_embeds_input for transformer", step)
319+
302320

303321
noise_pred, _ = transformer_forward_pass(
304322
graphdef, sharded_state, rest_of_state,
@@ -307,11 +325,17 @@ def loop_body(step, vals):
307325
guidance_scale=guidance_scale,
308326
encoder_hidden_states_image=image_embeds_input,
309327
)
328+
check_nan_jit(noise_pred, "noise_pred_bcthw from transformer", step)
310329
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
330+
check_nan_jit(noise_pred, "noise_pred after transpose", step)
311331

312332
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
333+
check_nan_jit(latents, "latents_next after scheduler", step)
313334
latents = latents.astype(original_dtype)
335+
check_nan_jit(latents, "latents_next after dtype cast", step)
314336
return latents, scheduler_state, rng
315337

338+
max_logging.log(f"Running fori_loop for {num_inference_steps} steps.")
316339
latents, _, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state, rng))
340+
max_logging.log("Finished fori_loop.")
317341
return latents

0 commit comments

Comments
 (0)