Skip to content

Commit af28002

Browse files
committed
reverted
1 parent 082f1f9 commit af28002

1 file changed

Lines changed: 2 additions & 52 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 2 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -152,9 +152,8 @@ def prepare_latents(
152152
jax.debug.print("condition stats: mask_mean={mm}, latent_mean={lm}",
153153
mm=jnp.mean(condition[..., 0]),
154154
lm=jnp.mean(condition[..., 1:]))
155-
jax.debug.print("condition latent std={std}", std=jnp.std(condition[..., 1:]))
156155

157-
return latents, condition, first_frame_mask
156+
return latents, condition, None
158157

159158

160159
def __call__(
@@ -213,14 +212,6 @@ def __call__(
213212
last_image=last_image_tensor,
214213
num_videos_per_prompt=num_videos_per_prompt,
215214
)
216-
if first_frame_mask is not None:
217-
jax.debug.print("FIRST FRAME MASK stats: min={mn}, max={mx}, mean={mean}, shape={shape}",
218-
mn=jnp.min(first_frame_mask),
219-
mx=jnp.max(first_frame_mask),
220-
mean=jnp.mean(first_frame_mask),
221-
shape=first_frame_mask.shape)
222-
else:
223-
jax.debug.print("FIRST FRAME MASK is None")
224215

225216
scheduler_state = self.scheduler.set_timesteps(
226217
self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
@@ -312,31 +303,15 @@ def loop_body(step, vals):
312303
original_dtype = latents.dtype
313304
rng, timestep_rng = jax.random.split(rng)
314305
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
315-
jax.debug.print("Step {s}: timestep={t}", s=step, t=t)
316306

317307
latents_input = latents
318308
if do_classifier_free_guidance:
319309
latents_input = jnp.concatenate([latents, latents], axis=0)
320-
jax.debug.print("Step{s}: latents_input stats min={mn}, max={mx}, mean={mean}, std={std}",
321-
s=step,
322-
mn=jnp.min(latents_input),
323-
mx=jnp.max(latents_input),
324-
mean=jnp.mean(latents_input),
325-
std=jnp.std(latents_input))
326310

327311
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
328312
timestep = jnp.broadcast_to(t, latents_input.shape[0])
329313
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
330314

331-
jax.debug.print("Step {s}: latent_model_input shape: {shape}",
332-
s=step,
333-
shape=latent_model_input.shape)
334-
335-
channel_energy = jnp.sum(latent_model_input*latent_model_input,axis=(0,2,3,4))
336-
jax.debug.print("Step {s}: channel energy first 10={ce}",
337-
s=step,
338-
ce=channel_energy[:10])
339-
340315
prompt_embeds_input = prompt_embeds
341316
image_embeds_input = image_embeds
342317

@@ -349,36 +324,11 @@ def loop_body(step, vals):
349324
encoder_hidden_states_image=image_embeds_input,
350325
)
351326
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
352-
def print_step_zero_stats(operands):
353-
l, np_pred = operands
354-
jax.debug.print("STEP 0: latents std={ls}, noise_pred std ={ns}, latents mean={lm}, noise mean={nm}",
355-
ls=jnp.std(l),
356-
ns=jnp.std(np_pred),
357-
lm=jnp.mean(l),
358-
nm=jnp.mean(np_pred))
359-
jax.lax.cond(
360-
step == 0,
361-
print_step_zero_stats,
362-
lambda _: None, # Do nothing if step != 0
363-
(latents, noise_pred)
364-
)
365-
jax.debug.print("Step {s}: noise_pred stats min={mn}, max={mx}, mean={mean}, std={std}",
366-
s=step,
367-
mn=jnp.min(noise_pred),
368-
mx=jnp.max(noise_pred),
369-
mean=jnp.mean(noise_pred),
370-
std=jnp.std(noise_pred))
371327
jax.debug.print("Step {s}: latents_prev std={std}, mean={mean}",
372328
s=step,
373329
std=jnp.std(latents),
374330
mean=jnp.mean(latents))
375-
jax.debug.print("first_frame_mask shape: {}", first_frame_mask.shape if first_frame_mask is not None else (-1,))
376-
jax.debug.print("condition shape: {}", condition.shape)
377-
jax.debug.print("condition stats: {}, {}, {}", jnp.min(condition), jnp.max(condition), jnp.mean(condition))
378331
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
379-
if first_frame_mask is not None:
380-
clean_latents = condition[..., 4:]
381-
latents = first_frame_mask * clean_latents + (1.0 - first_frame_mask) * latents
382332
jax.debug.print("Step {s}: latents_next std={std}, mean={mean}",
383333
s=step,
384334
std=jnp.std(latents),
@@ -394,4 +344,4 @@ def print_step_zero_stats(operands):
394344
lmean=jnp.mean(latents),
395345
lstd=jnp.std(latents))
396346
max_logging.log("Finished fori_loop.")
397-
return latents
347+
return latents

0 commit comments

Comments
 (0)