@@ -131,7 +131,7 @@ def prepare_latents(
131131 jax .debug .print ("first_frame_mask.shape:{shape}, is None:{isnone}" ,
132132 shape = first_frame_mask .shape if first_frame_mask is not None else (- 1 ,),
133133 isnone = first_frame_mask is None )
134- jax .debug .print ("first_frame_mask_stats: min={mn:.2f }, max={mx:.2f }, mean={mean:.2f }" ,
134+ jax .debug .print ("first_frame_mask_stats: min={mn}, max={mx}, mean={mean}" ,
135135 mn = jnp .min (first_frame_mask ) if first_frame_mask is not None else 0.0 ,
136136 mx = jnp .max (first_frame_mask ) if first_frame_mask is not None else 0.0 ,
137137 mean = jnp .mean (first_frame_mask ) if first_frame_mask is not None else 0.0 )
@@ -149,7 +149,7 @@ def prepare_latents(
149149 jax .debug .print ("condition shape: {shape}, channel dim: {c}" ,
150150 shape = condition .shape ,
151151 c = condition .shape [- 1 ])
152- jax .debug .print ("condition stats: mask_mean={mm:.4f }, latent_mean={lm:.4f }" ,
152+ jax .debug .print ("condition stats: mask_mean={mm}, latent_mean={lm}" ,
153153 mm = jnp .mean (condition [..., 0 ]),
154154 lm = jnp .mean (condition [..., 1 :]))
155155
@@ -317,12 +317,12 @@ def loop_body(step, vals):
317317 encoder_hidden_states_image = image_embeds_input ,
318318 )
319319 noise_pred = jnp .transpose (noise_pred , (0 , 2 , 3 , 4 , 1 ))
320- jax .debug .print ("Step {s}: latents_prev std={std:.6f }, mean={mean:.6f }" ,
320+ jax .debug .print ("Step {s}: latents_prev std={std}, mean={mean}" ,
321321 s = step ,
322322 std = jnp .std (latents ),
323323 mean = jnp .mean (latents ))
324324 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
325- jax .debug .print ("Step {s}: latents_next std={std:.6f }, mean={mean:.6f }" ,
325+ jax .debug .print ("Step {s}: latents_next std={std}, mean={mean}" ,
326326 s = step ,
327327 std = jnp .std (latents ),
328328 mean = jnp .mean (latents ))
@@ -331,7 +331,7 @@ def loop_body(step, vals):
331331
332332 max_logging .log (f"Running fori_loop for { num_inference_steps } steps." )
333333 latents , _ , _ = jax .lax .fori_loop (0 , num_inference_steps , loop_body , (latents , scheduler_state , rng ))
334- jax .debug .print ("Final latents states: min={lmin:.6f }, max={lmax:.6f }, mean={lmean:.6f }, std={lstd:.6f }" ,
334+ jax .debug .print ("Final latents states: min={lmin}, max={lmax}, mean={lmean}, std={lstd}" ,
335335 lmin = jnp .min (latents ),
336336 lmax = jnp .max (latents ),
337337 lmean = jnp .mean (latents ),
0 commit comments