@@ -602,6 +602,7 @@ def _prepare_model_inputs_i2v(
602602 if prompt is not None and isinstance (prompt , str ):
603603 prompt = [prompt ]
604604 batch_size = len (prompt ) if prompt is not None else prompt_embeds .shape [0 ] // num_videos_per_prompt
605+ print (f"[DEBUG PREP] num_prompts={ batch_size } , num_videos_per_prompt={ num_videos_per_prompt } " )
605606 effective_batch_size = batch_size * num_videos_per_prompt
606607
607608 # 1. Encode Prompts
@@ -613,6 +614,8 @@ def _prepare_model_inputs_i2v(
613614 prompt_embeds = prompt_embeds ,
614615 negative_prompt_embeds = negative_prompt_embeds ,
615616 )
617+ print (f"[DEBUG PREP] prompt_embeds shape after encode_prompt: { prompt_embeds .shape } " )
618+
616619
617620 # 2. Encode Image
618621 if image_embeds is None :
@@ -622,9 +625,11 @@ def _prepare_model_inputs_i2v(
622625 else :
623626 images_to_encode = [image , last_image ]
624627 image_embeds = self .encode_image (images_to_encode , num_videos_per_prompt = num_videos_per_prompt )
628+ print (f"[DEBUG PREP] image_embeds shape after encode_image: { image_embeds .shape } " )
625629
626630 if batch_size > 1 :
627631 image_embeds = jnp .tile (image_embeds , (batch_size , 1 , 1 ))
632+ print (f"[DEBUG PREP] image_embeds shape after tile: { image_embeds .shape } " )
628633
629634 transformer_dtype = self .config .activations_dtype
630635 image_embeds = image_embeds .astype (transformer_dtype )
@@ -633,11 +638,21 @@ def _prepare_model_inputs_i2v(
633638 negative_prompt_embeds = negative_prompt_embeds .astype (transformer_dtype )
634639
635640 data_sharding = NamedSharding (self .mesh , P (* self .config .data_sharding ))
641+ print (f"[DEBUG PREP] data_sharding spec: { self .config .data_sharding } " )
636642
637643 prompt_embeds = jax .device_put (prompt_embeds , data_sharding )
638644 negative_prompt_embeds = jax .device_put (negative_prompt_embeds , data_sharding )
639645 image_embeds = jax .device_put (image_embeds , data_sharding )
640646
647+ print (f"[DEBUG PREP] SHARDED prompt_embeds.shape: { prompt_embeds .shape } " )
648+ print (f"[DEBUG PREP] SHARDED image_embeds.shape: { image_embeds .shape } " )
649+ print (f"[DEBUG PREP] jax.process_index(): { jax .process_index ()} " )
650+
651+ if image_embeds .addressable_shards :
652+ print (f"[DEBUG PREP] LOCAL image_embeds shape: { image_embeds .addressable_shards [0 ].data .shape } " )
653+ if prompt_embeds .addressable_shards :
654+ print (f"[DEBUG PREP] LOCAL prompt_embeds shape: { prompt_embeds .addressable_shards [0 ].data .shape } " )
655+
641656 return prompt_embeds , negative_prompt_embeds , image_embeds , effective_batch_size
642657
643658
0 commit comments