Skip to content

Commit 649e8c1

Browse files
committed
debug statements added
1 parent 1b67d71 commit 649e8c1

2 files changed

Lines changed: 7 additions & 5 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -403,14 +403,11 @@ def encode_image(self, image: PipelineImageInput, num_videos_per_prompt: int = 1
403403
image = [image]
404404
image_inputs = self.image_processor(images=image, return_tensors="np")
405405
pixel_values = jnp.array(image_inputs.pixel_values)
406-
max_logging.log(f"[DEBUG ENC] pixel_values shape: {pixel_values.shape}")
407406

408407
image_encoder_output = self.image_encoder(pixel_values, output_hidden_states=True)
409408
image_embeds = image_encoder_output.hidden_states[-2]
410-
max_logging.log(f"[DEBUG ENC] Shape of image_embeds from image_encoder: {image_embeds.shape}")
411-
409+
412410
image_embeds = jnp.repeat(image_embeds, num_videos_per_prompt, axis=0)
413-
max_logging.log(f"[DEBUG ENC] Shape of image_embeds after repeat: {image_embeds.shape}")
414411
return image_embeds
415412

416413

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,12 @@ def __call__(
203203
t0 = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[0]
204204
dummy_noise = jnp.zeros_like(latents)
205205
# This call initializes the internal state arrays
206-
_, scheduler_state = self.scheduler.step(scheduler_state, dummy_noise, t0, latents)
206+
step_output = self.scheduler.step(scheduler_state, dummy_noise, t0, latents)
207+
max_logging.log(f"[DEBUG] scheduler.step output type: {type(step_output)}")
208+
max_logging.log(f"[DEBUG] scheduler.step output value: {step_output}")
209+
_, scheduler_state = step_output
210+
max_logging.log(f"[DEBUG] After prime step: scheduler_state type: {type(scheduler_state)}")
211+
max_logging.log(f"[DEBUG] After prime step: scheduler_state value: {scheduler_state}")
207212
max_logging.log(f"[DEBUG] Scheduler state primed: step_index={scheduler_state.step_index is not None}, last_sample={scheduler_state.last_sample is not None}")
208213

209214
graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)

0 commit comments

Comments
 (0)