Skip to content

Commit 228b995

Browse files
committed
removed debug statements
1 parent 405d9c2 commit 228b995

2 files changed

Lines changed: 0 additions & 22 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,6 @@ def _prepare_model_inputs_i2v(
602602
if prompt is not None and isinstance(prompt, str):
603603
prompt = [prompt]
604604
batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] // num_videos_per_prompt
605-
print(f"[DEBUG PREP] num_prompts={batch_size}, num_videos_per_prompt={num_videos_per_prompt}")
606605
effective_batch_size = batch_size * num_videos_per_prompt
607606

608607
# 1. Encode Prompts
@@ -614,7 +613,6 @@ def _prepare_model_inputs_i2v(
614613
prompt_embeds=prompt_embeds,
615614
negative_prompt_embeds=negative_prompt_embeds,
616615
)
617-
print(f"[DEBUG PREP] prompt_embeds shape after encode_prompt: {prompt_embeds.shape}")
618616

619617

620618
# 2. Encode Image
@@ -625,11 +623,9 @@ def _prepare_model_inputs_i2v(
625623
else:
626624
images_to_encode = [image, last_image]
627625
image_embeds = self.encode_image(images_to_encode, num_videos_per_prompt=num_videos_per_prompt)
628-
print(f"[DEBUG PREP] image_embeds shape after encode_image: {image_embeds.shape}")
629626

630627
if batch_size > 1:
631628
image_embeds = jnp.tile(image_embeds, (batch_size, 1, 1))
632-
print(f"[DEBUG PREP] image_embeds shape after tile: {image_embeds.shape}")
633629

634630
transformer_dtype = self.config.activations_dtype
635631
image_embeds = image_embeds.astype(transformer_dtype)
@@ -638,21 +634,11 @@ def _prepare_model_inputs_i2v(
638634
negative_prompt_embeds = negative_prompt_embeds.astype(transformer_dtype)
639635

640636
data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
641-
print(f"[DEBUG PREP] data_sharding spec: {self.config.data_sharding}")
642637

643638
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
644639
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
645640
image_embeds = jax.device_put(image_embeds, data_sharding)
646641

647-
print(f"[DEBUG PREP] SHARDED prompt_embeds.shape: {prompt_embeds.shape}")
648-
print(f"[DEBUG PREP] SHARDED image_embeds.shape: {image_embeds.shape}")
649-
print(f"[DEBUG PREP] jax.process_index(): {jax.process_index()}")
650-
651-
if image_embeds.addressable_shards:
652-
print(f"[DEBUG PREP] LOCAL image_embeds shape: {image_embeds.addressable_shards[0].data.shape}")
653-
if prompt_embeds.addressable_shards:
654-
print(f"[DEBUG PREP] LOCAL prompt_embeds shape: {prompt_embeds.addressable_shards[0].data.shape}")
655-
656642
return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size
657643

658644

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -267,12 +267,6 @@ def loop_body(step, vals):
267267
rng, timestep_rng = jax.random.split(rng)
268268
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
269269

270-
print(f"[DEBUG LOOP {step}] on process {jax.process_index()}:")
271-
print(f"[DEBUG LOOP {step}] initial latents local shape: {latents.shape}")
272-
print(f"[DEBUG LOOP {step}] initial prompt_embeds local shape: {prompt_embeds.shape}")
273-
print(f"[DEBUG LOOP {step}] initial image_embeds local shape: {image_embeds.shape}")
274-
275-
276270
latents_input = latents
277271
if do_classifier_free_guidance:
278272
latents_input = jnp.concatenate([latents, latents], axis=0)
@@ -283,8 +277,6 @@ def loop_body(step, vals):
283277

284278
prompt_embeds_input = prompt_embeds
285279
image_embeds_input = image_embeds
286-
print(f"[DEBUG LOOP {step}] prompt_embeds_input local shape: {prompt_embeds_input.shape}")
287-
print(f"[DEBUG LOOP {step}] image_embeds_input local shape: {image_embeds_input.shape}")
288280

289281
noise_pred, latents = transformer_forward_pass(
290282
graphdef, sharded_state, rest_of_state,

0 commit comments

Comments
 (0)