@@ -152,6 +152,7 @@ def prepare_latents(
152152 jax .debug .print ("condition stats: mask_mean={mm}, latent_mean={lm}" ,
153153 mm = jnp .mean (condition [..., 0 ]),
154154 lm = jnp .mean (condition [..., 1 :]))
155+ jax .debug .print ("condition latent std={std}" , std = jnp .std (condition [..., 1 :]))
155156
156157 return latents , condition , None
157158
@@ -303,15 +304,31 @@ def loop_body(step, vals):
303304 original_dtype = latents .dtype
304305 rng , timestep_rng = jax .random .split (rng )
305306 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
307+ jax .debug .print ("Step {s}: timestep={t}" , s = step , t = t )
306308
307309 latents_input = latents
308310 if do_classifier_free_guidance :
309311 latents_input = jnp .concatenate ([latents , latents ], axis = 0 )
312+ jax .debug .print ("Step{s}: latents_input stats min={mn}, max={mx}, mean={mean}, std={std}" ,
313+ s = step ,
314+ mn = jnp .min (latents_input ),
315+ mx = jnp .max (latents_input ),
316+ mean = jnp .mean (latents_input ),
317+ std = jnp .std (latents_input ))
310318
311319 latent_model_input = jnp .concatenate ([latents_input , condition ], axis = - 1 )
312320 timestep = jnp .broadcast_to (t , latents_input .shape [0 ])
313321 latent_model_input = jnp .transpose (latent_model_input , (0 , 4 , 1 , 2 , 3 ))
314322
323+ jax .debug .print ("Step {s}: latent_model_input shape: {shape}" ,
324+ s = step ,
325+ shape = latent_model_input .shape )
326+
327+ channel_energy = jnp .sum (latent_model_input * latent_model_input ,axis = (0 ,2 ,3 ,4 ))
328+ jax .debug .print ("Step {s}: channel energy first 10={ce}" ,
329+ s = step ,
330+ ce = channel_energy [:10 ])
331+
315332 prompt_embeds_input = prompt_embeds
316333 image_embeds_input = image_embeds
317334
@@ -324,10 +341,20 @@ def loop_body(step, vals):
324341 encoder_hidden_states_image = image_embeds_input ,
325342 )
326343 noise_pred = jnp .transpose (noise_pred , (0 , 2 , 3 , 4 , 1 ))
344+ jax .debug .print ("Step {s}: noise_pred stats min={mn}, max={mx}, mean={mean}, std={std}" ,
345+ s = step ,
346+ mn = jnp .min (noise_pred ),
347+ mx = jnp .max (noise_pred ),
348+ mean = jnp .mean (noise_pred ),
349+ std = jnp .std (noise_pred ))
327350 jax .debug .print ("Step {s}: latents_prev std={std}, mean={mean}" ,
328351 s = step ,
329352 std = jnp .std (latents ),
330353 mean = jnp .mean (latents ))
354+ jax .debug .print ("first_frame_mask shape:" , first_frame_mask .shape )
355+ jax .debug .print ("first_frame_mask unique values:" , jnp .unique (first_frame_mask ))
356+ jax .debug .print ("condition shape:" , condition .shape )
357+ jax .debug .print ("condition stats:" , jnp .min (condition ), jnp .max (condition ), jnp .mean (condition ))
331358 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
332359 jax .debug .print ("Step {s}: latents_next std={std}, mean={mean}" ,
333360 s = step ,
0 commit comments