@@ -602,7 +602,6 @@ def _prepare_model_inputs_i2v(
602602 if prompt is not None and isinstance (prompt , str ):
603603 prompt = [prompt ]
604604 batch_size = len (prompt ) if prompt is not None else prompt_embeds .shape [0 ] // num_videos_per_prompt
605- print (f"[DEBUG PREP] num_prompts={ batch_size } , num_videos_per_prompt={ num_videos_per_prompt } " )
606605 effective_batch_size = batch_size * num_videos_per_prompt
607606
608607 # 1. Encode Prompts
@@ -614,7 +613,6 @@ def _prepare_model_inputs_i2v(
614613 prompt_embeds = prompt_embeds ,
615614 negative_prompt_embeds = negative_prompt_embeds ,
616615 )
617- print (f"[DEBUG PREP] prompt_embeds shape after encode_prompt: { prompt_embeds .shape } " )
618616
619617
620618 # 2. Encode Image
@@ -625,11 +623,9 @@ def _prepare_model_inputs_i2v(
625623 else :
626624 images_to_encode = [image , last_image ]
627625 image_embeds = self .encode_image (images_to_encode , num_videos_per_prompt = num_videos_per_prompt )
628- print (f"[DEBUG PREP] image_embeds shape after encode_image: { image_embeds .shape } " )
629626
630627 if batch_size > 1 :
631628 image_embeds = jnp .tile (image_embeds , (batch_size , 1 , 1 ))
632- print (f"[DEBUG PREP] image_embeds shape after tile: { image_embeds .shape } " )
633629
634630 transformer_dtype = self .config .activations_dtype
635631 image_embeds = image_embeds .astype (transformer_dtype )
@@ -638,21 +634,11 @@ def _prepare_model_inputs_i2v(
638634 negative_prompt_embeds = negative_prompt_embeds .astype (transformer_dtype )
639635
640636 data_sharding = NamedSharding (self .mesh , P (* self .config .data_sharding ))
641- print (f"[DEBUG PREP] data_sharding spec: { self .config .data_sharding } " )
642637
643638 prompt_embeds = jax .device_put (prompt_embeds , data_sharding )
644639 negative_prompt_embeds = jax .device_put (negative_prompt_embeds , data_sharding )
645640 image_embeds = jax .device_put (image_embeds , data_sharding )
646641
647- print (f"[DEBUG PREP] SHARDED prompt_embeds.shape: { prompt_embeds .shape } " )
648- print (f"[DEBUG PREP] SHARDED image_embeds.shape: { image_embeds .shape } " )
649- print (f"[DEBUG PREP] jax.process_index(): { jax .process_index ()} " )
650-
651- if image_embeds .addressable_shards :
652- print (f"[DEBUG PREP] LOCAL image_embeds shape: { image_embeds .addressable_shards [0 ].data .shape } " )
653- if prompt_embeds .addressable_shards :
654- print (f"[DEBUG PREP] LOCAL prompt_embeds shape: { prompt_embeds .addressable_shards [0 ].data .shape } " )
655-
656642 return prompt_embeds , negative_prompt_embeds , image_embeds , effective_batch_size
657643
658644
0 commit comments