@@ -107,6 +107,10 @@ def prepare_latents(
107107
108108 num_channels_latents = self .vae .z_dim
109109 num_latent_frames = (num_frames - 1 ) // self .vae_scale_factor_temporal + 1
110+ jax .debug .print ("num_frames: {nf}, num_latent_frames: {nlf}, expected: {exp}" ,
111+ nf = num_frames ,
112+ nlf = latents .shape [1 ],
113+ exp = num_latent_frames )
110114 latent_height = height // self .vae_scale_factor_spatial
111115 latent_width = width // self .vae_scale_factor_spatial
112116
@@ -124,6 +128,13 @@ def prepare_latents(
124128 mask_lat_size = mask_lat_size .at [:, :, 1 :- 1 , :, :].set (0 )
125129 first_frame_mask = mask_lat_size [:, :, 0 :1 ]
126130 first_frame_mask = jnp .repeat (first_frame_mask , self .vae_scale_factor_temporal , axis = 2 )
131+ jax .debug .print ("first_frame_mask.shape:{shape}, is None:{isnone}" ,
132+ shape = first_frame_mask .shape if first_frame_mask is not None else (- 1 ,),
133+ isnone = first_frame_mask is None )
134+ jax .debug .print ("first_frame_mask_stats: min={mn:.2f}, max={mx:.2f}, mean={mean:.2f}" ,
135+ mn = jnp .min (first_frame_mask ) if first_frame_mask is not None else 0.0 ,
136+ mx = jnp .max (first_frame_mask ) if first_frame_mask is not None else 0.0 ,
137+ mean = jnp .mean (first_frame_mask ) if first_frame_mask is not None else 0.0 )
127138 mask_lat_size = jnp .concatenate ([first_frame_mask , mask_lat_size [:, :, 1 :]], axis = 2 )
128139 mask_lat_size = mask_lat_size .reshape (
129140 batch_size ,
@@ -135,6 +146,12 @@ def prepare_latents(
135146 )
136147 mask_lat_size = jnp .transpose (mask_lat_size , (0 , 2 , 4 , 5 , 3 , 1 )).squeeze (- 1 )
137148 condition = jnp .concatenate ([mask_lat_size , latent_condition ], axis = - 1 )
149+ jax .debug .print ("condition shape: {shape}, channel dim: {c}" ,
150+ shape = condition .shape ,
151+ c = condition .shape [- 1 ])
152+ jax .debug .print ("condition stats: mask_mean={mm:.4f}, latent_mean={lm:.4f}" ,
153+ mm = jnp .mean (condition [..., 0 ]),
154+ lm = jnp .mean (condition [..., 1 :]))
138155
139156 return latents , condition , None
140157
@@ -300,11 +317,24 @@ def loop_body(step, vals):
300317 encoder_hidden_states_image = image_embeds_input ,
301318 )
302319 noise_pred = jnp .transpose (noise_pred , (0 , 2 , 3 , 4 , 1 ))
320+ jax .debug .print ("Step {s}: latents_prev std={std:.6f}, mean={mean:.6f}" ,
321+ s = step ,
322+ std = jnp .std (latents ),
323+ mean = jnp .mean (latents ))
303324 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
325+ jax .debug .print ("Step {s}: latents_next std={std:.6f}, mean={mean:.6f}" ,
326+ s = step ,
327+ std = jnp .std (latents ),
328+ mean = jnp .mean (latents ))
304329 latents = latents .astype (original_dtype )
305330 return latents , scheduler_state , rng
306331
307332 max_logging .log (f"Running fori_loop for { num_inference_steps } steps." )
308333 latents , _ , _ = jax .lax .fori_loop (0 , num_inference_steps , loop_body , (latents , scheduler_state , rng ))
334+ jax .debug .print ("Final latents states: min={lmin:.6f}, max={lmax:.6f}, mean={lmean:.6f}, std={lstd:.6f}" ,
335+ lmin = jnp .min (latents ),
336+ lmax = jnp .max (latents ),
337+ lmean = jnp .mean (latents ),
338+ lstd = jnp .std (latents ))
309339 max_logging .log ("Finished fori_loop." )
310340 return latents
0 commit comments