@@ -152,9 +152,8 @@ def prepare_latents(
152152 jax .debug .print ("condition stats: mask_mean={mm}, latent_mean={lm}" ,
153153 mm = jnp .mean (condition [..., 0 ]),
154154 lm = jnp .mean (condition [..., 1 :]))
155- jax .debug .print ("condition latent std={std}" , std = jnp .std (condition [..., 1 :]))
156155
157- return latents , condition , first_frame_mask
156+ return latents , condition , None
158157
159158
160159 def __call__ (
@@ -213,14 +212,6 @@ def __call__(
213212 last_image = last_image_tensor ,
214213 num_videos_per_prompt = num_videos_per_prompt ,
215214 )
216- if first_frame_mask is not None :
217- jax .debug .print ("FIRST FRAME MASK stats: min={mn}, max={mx}, mean={mean}, shape={shape}" ,
218- mn = jnp .min (first_frame_mask ),
219- mx = jnp .max (first_frame_mask ),
220- mean = jnp .mean (first_frame_mask ),
221- shape = first_frame_mask .shape )
222- else :
223- jax .debug .print ("FIRST FRAME MASK is None" )
224215
225216 scheduler_state = self .scheduler .set_timesteps (
226217 self .scheduler_state , num_inference_steps = num_inference_steps , shape = latents .shape
@@ -312,31 +303,15 @@ def loop_body(step, vals):
312303 original_dtype = latents .dtype
313304 rng , timestep_rng = jax .random .split (rng )
314305 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
315- jax .debug .print ("Step {s}: timestep={t}" , s = step , t = t )
316306
317307 latents_input = latents
318308 if do_classifier_free_guidance :
319309 latents_input = jnp .concatenate ([latents , latents ], axis = 0 )
320- jax .debug .print ("Step{s}: latents_input stats min={mn}, max={mx}, mean={mean}, std={std}" ,
321- s = step ,
322- mn = jnp .min (latents_input ),
323- mx = jnp .max (latents_input ),
324- mean = jnp .mean (latents_input ),
325- std = jnp .std (latents_input ))
326310
327311 latent_model_input = jnp .concatenate ([latents_input , condition ], axis = - 1 )
328312 timestep = jnp .broadcast_to (t , latents_input .shape [0 ])
329313 latent_model_input = jnp .transpose (latent_model_input , (0 , 4 , 1 , 2 , 3 ))
330314
331- jax .debug .print ("Step {s}: latent_model_input shape: {shape}" ,
332- s = step ,
333- shape = latent_model_input .shape )
334-
335- channel_energy = jnp .sum (latent_model_input * latent_model_input ,axis = (0 ,2 ,3 ,4 ))
336- jax .debug .print ("Step {s}: channel energy first 10={ce}" ,
337- s = step ,
338- ce = channel_energy [:10 ])
339-
340315 prompt_embeds_input = prompt_embeds
341316 image_embeds_input = image_embeds
342317
@@ -349,36 +324,11 @@ def loop_body(step, vals):
349324 encoder_hidden_states_image = image_embeds_input ,
350325 )
351326 noise_pred = jnp .transpose (noise_pred , (0 , 2 , 3 , 4 , 1 ))
352- def print_step_zero_stats (operands ):
353- l , np_pred = operands
354- jax .debug .print ("STEP 0: latents std={ls}, noise_pred std ={ns}, latents mean={lm}, noise mean={nm}" ,
355- ls = jnp .std (l ),
356- ns = jnp .std (np_pred ),
357- lm = jnp .mean (l ),
358- nm = jnp .mean (np_pred ))
359- jax .lax .cond (
360- step == 0 ,
361- print_step_zero_stats ,
362- lambda _ : None , # Do nothing if step != 0
363- (latents , noise_pred )
364- )
365- jax .debug .print ("Step {s}: noise_pred stats min={mn}, max={mx}, mean={mean}, std={std}" ,
366- s = step ,
367- mn = jnp .min (noise_pred ),
368- mx = jnp .max (noise_pred ),
369- mean = jnp .mean (noise_pred ),
370- std = jnp .std (noise_pred ))
371327 jax .debug .print ("Step {s}: latents_prev std={std}, mean={mean}" ,
372328 s = step ,
373329 std = jnp .std (latents ),
374330 mean = jnp .mean (latents ))
375- jax .debug .print ("first_frame_mask shape: {}" , first_frame_mask .shape if first_frame_mask is not None else (- 1 ,))
376- jax .debug .print ("condition shape: {}" , condition .shape )
377- jax .debug .print ("condition stats: {}, {}, {}" , jnp .min (condition ), jnp .max (condition ), jnp .mean (condition ))
378331 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
379- if first_frame_mask is not None :
380- clean_latents = condition [..., 4 :]
381- latents = first_frame_mask * clean_latents + (1.0 - first_frame_mask ) * latents
382332 jax .debug .print ("Step {s}: latents_next std={std}, mean={mean}" ,
383333 s = step ,
384334 std = jnp .std (latents ),
@@ -394,4 +344,4 @@ def print_step_zero_stats(operands):
394344 lmean = jnp .mean (latents ),
395345 lstd = jnp .std (latents ))
396346 max_logging .log ("Finished fori_loop." )
397- return latents
347+ return latents
0 commit comments