Skip to content

Commit 444c106

Browse files
committed
removed debug from i2v 2.1
1 parent 68967f8 commit 444c106

1 file changed

Lines changed: 0 additions & 21 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -257,15 +257,6 @@ def __call__(
257257
max_logging.log(f"[DEBUG CALL] Decoded video type: {type(decoded_video)}")
258258
return decoded_video
259259

260-
def check_nan_jit(tensor: jax.Array, name: str, step: jax.Array):
261-
if tensor is None:
262-
return
263-
264-
has_nans = jnp.isnan(tensor).any()
265-
has_infs = jnp.isinf(tensor).any()
266-
jax.debug.print(f"[DEBUG JIT {jax.process_index()}] Step: {{step}} - {name}: "
267-
"Shape: {shape}, Has NaNs: {has_nans_val}, Has Infs: {has_infs_val}",
268-
step=step, shape=tensor.shape, has_nans_val=has_nans, has_infs_val=has_infs)
269260

270261
def run_inference_2_1_i2v(
271262
graphdef, sharded_state, rest_of_state,
@@ -300,23 +291,16 @@ def loop_body(step, vals):
300291
rng, timestep_rng = jax.random.split(rng)
301292
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
302293

303-
check_nan_jit(latents, "latents_prev at loop start", step)
304-
305294
latents_input = latents
306295
if do_classifier_free_guidance:
307296
latents_input = jnp.concatenate([latents, latents], axis=0)
308-
check_nan_jit(latents_input, "latents_input after CFG concat", step)
309297

310298
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
311-
check_nan_jit(latent_model_input, "latent_model_input after cond concat", step)
312299
timestep = jnp.broadcast_to(t, latents_input.shape[0])
313300
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
314-
check_nan_jit(latent_model_input, "latent_model_input for transformer", step)
315301

316302
prompt_embeds_input = prompt_embeds
317303
image_embeds_input = image_embeds
318-
check_nan_jit(prompt_embeds_input, "prompt_embeds_input for transformer", step)
319-
check_nan_jit(image_embeds_input, "image_embeds_input for transformer", step)
320304

321305

322306
noise_pred, _ = transformer_forward_pass(
@@ -326,14 +310,9 @@ def loop_body(step, vals):
326310
guidance_scale=guidance_scale,
327311
encoder_hidden_states_image=image_embeds_input,
328312
)
329-
check_nan_jit(noise_pred, "noise_pred_bcthw from transformer", step)
330313
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
331-
check_nan_jit(noise_pred, "noise_pred after transpose", step)
332-
333314
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
334-
check_nan_jit(latents, "latents_next after scheduler", step)
335315
latents = latents.astype(original_dtype)
336-
check_nan_jit(latents, "latents_next after dtype cast", step)
337316
return latents, scheduler_state, rng
338317

339318
max_logging.log(f"Running fori_loop for {num_inference_steps} steps.")

0 commit comments

Comments
 (0)