@@ -434,18 +434,36 @@ def _get_t5_prompt_embeds(
         )
         text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
         seq_lens = mask.gt(0).sum(dim=1).long()
+
+        # DEBUG
+        print(f"[DEBUG _get_t5_prompt_embeds] seq_lens: {seq_lens.tolist()}, mask shape: {mask.shape}")
+
         prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
         prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
         prompt_embeds = torch.stack(
             [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
         )
 
+        # Create text attention mask (1 for real tokens, 0 for padding)
+        text_attention_mask = torch.zeros((batch_size, max_sequence_length), dtype=torch.long)
+        for i, seq_len_i in enumerate(seq_lens):
+            text_attention_mask[i, :seq_len_i] = 1
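+        # (Equivalently, without the Python loop:
+        #  (torch.arange(max_sequence_length)[None, :] < seq_lens[:, None]).long())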
+
+        print(f"[DEBUG _get_t5_prompt_embeds] text_attention_mask shape: {text_attention_mask.shape}, sum: {text_attention_mask.sum(dim=1).tolist()}")
+
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+        # Duplicate mask
+        text_attention_mask = text_attention_mask.repeat(1, num_videos_per_prompt)
+        text_attention_mask = text_attention_mask.view(batch_size * num_videos_per_prompt, max_sequence_length)
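+        # Hand the mask off to JAX here; the embeddings themselves are converted to jnp arrays in encode_prompt.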
+        text_attention_mask_jax = jnp.array(text_attention_mask.numpy())
+
+        print(f"[DEBUG _get_t5_prompt_embeds] After duplication - mask shape: {text_attention_mask_jax.shape}")
 
-        return prompt_embeds
+        return prompt_embeds, text_attention_mask_jax
 
     def encode_prompt(
         self,
@@ -459,24 +477,31 @@ def encode_prompt(
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
         if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
+            prompt_embeds, text_attention_mask = self._get_t5_prompt_embeds(
                 prompt=prompt,
                 num_videos_per_prompt=num_videos_per_prompt,
                 max_sequence_length=max_sequence_length,
             )
             prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32)
+        else:
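+            # Precomputed embeds skip the tokenizer pass, so no mask is available for them.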
+            text_attention_mask = None
 
         if negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            negative_prompt_embeds = self._get_t5_prompt_embeds(
+            negative_prompt_embeds, negative_text_attention_mask = self._get_t5_prompt_embeds(
                 prompt=negative_prompt,
                 num_videos_per_prompt=num_videos_per_prompt,
                 max_sequence_length=max_sequence_length,
             )
             negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32)
+        else:
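+            # Likewise, no mask when the negative embeds are precomputed.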
+            negative_text_attention_mask = None
+
+        print(f"[DEBUG encode_prompt] text_attention_mask: {text_attention_mask.shape if text_attention_mask is not None else None}")
+        print(f"[DEBUG encode_prompt] negative_text_attention_mask: {negative_text_attention_mask.shape if negative_text_attention_mask is not None else None}")
 
-        return prompt_embeds, negative_prompt_embeds
+        return prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask
 
     def prepare_latents(
         self,
@@ -687,13 +712,15 @@ def _prepare_model_inputs(
         batch_size = len(prompt)
 
         with jax.named_scope("Encode-Prompt"):
-            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask = self.encode_prompt(
                 prompt=prompt,
                 negative_prompt=negative_prompt,
                 max_sequence_length=max_sequence_length,
                 prompt_embeds=prompt_embeds,
                 negative_prompt_embeds=negative_prompt_embeds,
             )
+
+        print(f"[DEBUG _prepare_model_inputs] Got masks - text: {text_attention_mask.shape if text_attention_mask is not None else None}, neg: {negative_text_attention_mask.shape if negative_text_attention_mask is not None else None}")
 
         num_channel_latents = self._get_num_channel_latents()
         if latents is None:
@@ -715,12 +742,16 @@ def _prepare_model_inputs(
         latents = jax.device_put(latents, data_sharding)
         prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
         negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
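+        # Shard the masks the same way as the embeddings so per-device batch shapes agree.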
+        if text_attention_mask is not None:
+            text_attention_mask = jax.device_put(text_attention_mask, data_sharding)
+        if negative_text_attention_mask is not None:
+            negative_text_attention_mask = jax.device_put(negative_text_attention_mask, data_sharding)
 
         scheduler_state = self.scheduler.set_timesteps(
             self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
         )
 
-        return latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames
+        return latents, prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask, scheduler_state, num_frames
 
     @abstractmethod
     def __call__(self, **kwargs):