Skip to content

Commit 3fb489a

Browse files
committed
Debug logging added for shape mismatch
1 parent 9035b2f commit 3fb489a

2 files changed

Lines changed: 23 additions & 1 deletion

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,7 @@ def _prepare_model_inputs_i2v(
602602
if prompt is not None and isinstance(prompt, str):
603603
prompt = [prompt]
604604
batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] // num_videos_per_prompt
605+
print(f"[DEBUG PREP] num_prompts={batch_size}, num_videos_per_prompt={num_videos_per_prompt}")
605606
effective_batch_size = batch_size * num_videos_per_prompt
606607

607608
# 1. Encode Prompts
@@ -613,6 +614,8 @@ def _prepare_model_inputs_i2v(
613614
prompt_embeds=prompt_embeds,
614615
negative_prompt_embeds=negative_prompt_embeds,
615616
)
617+
print(f"[DEBUG PREP] prompt_embeds shape after encode_prompt: {prompt_embeds.shape}")
618+
616619

617620
# 2. Encode Image
618621
if image_embeds is None:
@@ -622,9 +625,11 @@ def _prepare_model_inputs_i2v(
622625
else:
623626
images_to_encode = [image, last_image]
624627
image_embeds = self.encode_image(images_to_encode, num_videos_per_prompt=num_videos_per_prompt)
628+
print(f"[DEBUG PREP] image_embeds shape after encode_image: {image_embeds.shape}")
625629

626630
if batch_size > 1:
627631
image_embeds = jnp.tile(image_embeds, (batch_size, 1, 1))
632+
print(f"[DEBUG PREP] image_embeds shape after tile: {image_embeds.shape}")
628633

629634
transformer_dtype = self.config.activations_dtype
630635
image_embeds = image_embeds.astype(transformer_dtype)
@@ -633,11 +638,21 @@ def _prepare_model_inputs_i2v(
633638
negative_prompt_embeds = negative_prompt_embeds.astype(transformer_dtype)
634639

635640
data_sharding = NamedSharding(self.mesh, P(*self.config.data_sharding))
641+
print(f"[DEBUG PREP] data_sharding spec: {self.config.data_sharding}")
636642

637643
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
638644
negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
639645
image_embeds = jax.device_put(image_embeds, data_sharding)
640646

647+
print(f"[DEBUG PREP] SHARDED prompt_embeds.shape: {prompt_embeds.shape}")
648+
print(f"[DEBUG PREP] SHARDED image_embeds.shape: {image_embeds.shape}")
649+
print(f"[DEBUG PREP] jax.process_index(): {jax.process_index()}")
650+
651+
if image_embeds.addressable_shards:
652+
print(f"[DEBUG PREP] LOCAL image_embeds shape: {image_embeds.addressable_shards[0].data.shape}")
653+
if prompt_embeds.addressable_shards:
654+
print(f"[DEBUG PREP] LOCAL prompt_embeds shape: {prompt_embeds.addressable_shards[0].data.shape}")
655+
641656
return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size
642657

643658

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,12 @@ def loop_body(step, vals):
267267
rng, timestep_rng = jax.random.split(rng)
268268
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
269269

270+
print(f"[DEBUG LOOP {step}] on process {jax.process_index()}:")
271+
print(f"[DEBUG LOOP {step}] initial latents local shape: {latents.shape}")
272+
print(f"[DEBUG LOOP {step}] initial prompt_embeds local shape: {prompt_embeds.shape}")
273+
print(f"[DEBUG LOOP {step}] initial image_embeds local shape: {image_embeds.shape}")
274+
275+
270276
latents_input = latents
271277
if do_classifier_free_guidance:
272278
latents_input = jnp.concatenate([latents, latents], axis=0)
@@ -281,7 +287,8 @@ def loop_body(step, vals):
281287
prompt_embeds_input = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
282288
if image_embeds is not None:
283289
image_embeds_input = jnp.concatenate([image_embeds, image_embeds], axis=0)
284-
290+
print(f"[DEBUG LOOP {step}] prompt_embeds_input local shape: {prompt_embeds_input.shape}")
291+
print(f"[DEBUG LOOP {step}] image_embeds_input local shape: {image_embeds_input.shape}")
285292

286293
noise_pred, latents = transformer_forward_pass(
287294
graphdef, sharded_state, rest_of_state,

0 commit comments

Comments (0)