@@ -154,7 +154,7 @@ def prepare_latents(
154154 lm = jnp .mean (condition [..., 1 :]))
155155 jax .debug .print ("condition latent std={std}" , std = jnp .std (condition [..., 1 :]))
156156
157- return latents , condition , None
157+ return latents , condition , first_frame_mask
158158
159159
160160 def __call__ (
@@ -213,6 +213,14 @@ def __call__(
213213 last_image = last_image_tensor ,
214214 num_videos_per_prompt = num_videos_per_prompt ,
215215 )
216+ if first_frame_mask is not None :
217+ jax .debug .print ("FIRST FRAME MASK stats: min={mn}, max={mx}, mean={mean}, shape={shape}" ,
218+ mn = jnp .min (first_frame_mask ),
219+ mx = jnp .max (first_frame_mask ),
220+ mean = jnp .mean (first_frame_mask ),
221+ shape = first_frame_mask .shape )
222+ else :
223+ jax .debug .print ("FIRST FRAME MASK is None" )
216224
217225 scheduler_state = self .scheduler .set_timesteps (
218226 self .scheduler_state , num_inference_steps = num_inference_steps , shape = latents .shape
@@ -341,6 +349,13 @@ def loop_body(step, vals):
341349 encoder_hidden_states_image = image_embeds_input ,
342350 )
343351 noise_pred = jnp .transpose (noise_pred , (0 , 2 , 3 , 4 , 1 ))
352+ if step == 0 :
353+ jax .debug .print ("STEP 0: latents std={ls}, noise_pred std ={ns}, latents mean={lm}, noise mean={nm}" ,
354+ ls = jnp .std (latents ),
355+ ns = jnp .std (noise_pred ),
356+ lm = jnp .mean (latents ),
357+ nm = jnp .mean (noise_pred ))
358+
344359 jax .debug .print ("Step {s}: noise_pred stats min={mn}, max={mx}, mean={mean}, std={std}" ,
345360 s = step ,
346361 mn = jnp .min (noise_pred ),
@@ -351,11 +366,14 @@ def loop_body(step, vals):
351366 s = step ,
352367 std = jnp .std (latents ),
353368 mean = jnp .mean (latents ))
354- jax .debug .print ("first_frame_mask shape:" , first_frame_mask .shape )
369+ jax .debug .print ("first_frame_mask shape:" , first_frame_mask .shape if first_frame_mask is not None else ( - 1 ,) )
355370 jax .debug .print ("first_frame_mask unique values:" , jnp .unique (first_frame_mask ))
356371 jax .debug .print ("condition shape:" , condition .shape )
357372 jax .debug .print ("condition stats:" , jnp .min (condition ), jnp .max (condition ), jnp .mean (condition ))
358373 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
374+ if first_frame_mask is not None :
375+ clean_latents = condition [..., 4 :]
376+ latents = first_frame_mask * clean_latents + (1.0 - first_frame_mask ) * latents
359377 jax .debug .print ("Step {s}: latents_next std={std}, mean={mean}" ,
360378 s = step ,
361379 std = jnp .std (latents ),
0 commit comments