@@ -419,9 +419,6 @@ def _get_t5_prompt_embeds(
         num_videos_per_prompt: int = 1,
         max_sequence_length: int = 226,
     ):
-        jax.debug.print("Prompt Shape")
-        jax.debug.print(f"Length of prompt list: {len(prompt)}")
-
         prompt = [prompt] if isinstance(prompt, str) else prompt
         prompt = [prompt_clean(u) for u in prompt]
         batch_size = len(prompt)
@@ -436,25 +433,29 @@ def _get_t5_prompt_embeds(
             return_tensors="pt",
         )
         text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
-        print("Text Input IDS")
-        print(text_input_ids)
-        print(text_input_ids.shape)
-        print("Mask")
-        print(mask)
-        print(mask.shape)
         seq_lens = mask.gt(0).sum(dim=1).long()
         prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
         prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
         prompt_embeds = torch.stack(
             [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
         )
 
+        # Create attention mask: 1 for real tokens, 0 for padded tokens
+        # This mask reflects the actual content after trimming and re-padding with zeros
+        text_attention_mask = torch.zeros((batch_size, max_sequence_length), dtype=torch.int32, device=mask.device)
+        for i, length in enumerate(seq_lens):
+            text_attention_mask[i, :length] = 1
+
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+        # duplicate attention mask for each generation per prompt
+        text_attention_mask = text_attention_mask.repeat(1, num_videos_per_prompt)
+        text_attention_mask = text_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
 
-        return prompt_embeds
+        return prompt_embeds, text_attention_mask
 
     def encode_prompt(
         self,
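
For reference (not part of the commit): a minimal sketch of what the mask construction above produces, plus an equivalent broadcast form. The tiny lengths are illustrative only; variable names mirror the diff.

import torch

# Hypothetical example: two prompts with 3 and 5 real tokens, padded to 6 positions.
seq_lens = torch.tensor([3, 5])
max_sequence_length = 6
batch_size = seq_lens.shape[0]

# Loop form, as in the diff above.
text_attention_mask = torch.zeros((batch_size, max_sequence_length), dtype=torch.int32)
for i, length in enumerate(seq_lens):
    text_attention_mask[i, :length] = 1

# Equivalent vectorized form: compare a position index against each true length.
positions = torch.arange(max_sequence_length)
vectorized = (positions[None, :] < seq_lens[:, None]).to(torch.int32)

assert torch.equal(text_attention_mask, vectorized)
# Both yield:
# [[1, 1, 1, 0, 0, 0],
#  [1, 1, 1, 1, 1, 0]]
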
@@ -467,25 +468,30 @@ def encode_prompt(
     ):
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
+        text_attention_mask = None
+        negative_text_attention_mask = None
+
         if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
+            prompt_embeds, text_attention_mask = self._get_t5_prompt_embeds(
                 prompt=prompt,
                 num_videos_per_prompt=num_videos_per_prompt,
                 max_sequence_length=max_sequence_length,
             )
             prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32)
+            text_attention_mask = jnp.array(text_attention_mask.detach().numpy(), dtype=jnp.int32)
 
         if negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            negative_prompt_embeds = self._get_t5_prompt_embeds(
+            negative_prompt_embeds, negative_text_attention_mask = self._get_t5_prompt_embeds(
                 prompt=negative_prompt,
                 num_videos_per_prompt=num_videos_per_prompt,
                 max_sequence_length=max_sequence_length,
             )
             negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32)
+            negative_text_attention_mask = jnp.array(negative_text_attention_mask.detach().numpy(), dtype=jnp.int32)
 
-        return prompt_embeds, negative_prompt_embeds
+        return prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask
 
     def prepare_latents(
         self,
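
For orientation (also not part of the commit): with the widened return signature, a classifier-free-guidance caller would presumably stack the conditional and unconditional halves along the batch axis, consistent with transformer_forward_pass below splitting noise_pred at bsz with the first half conditional. A hedged sketch only; pipeline, the example prompts, and the concatenation order are assumptions, not code from this repository.

import jax.numpy as jnp

prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask = pipeline.encode_prompt(
    prompt="a cat surfing a wave",
    negative_prompt="blurry, low quality",
)

# Conditional first, unconditional second, matching the noise_pred[:bsz] split below.
encoder_hidden_states = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
encoder_attention_mask = jnp.concatenate([text_attention_mask, negative_text_attention_mask], axis=0)
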
@@ -617,7 +623,7 @@ def _prepare_model_inputs_i2v(
         effective_batch_size = batch_size * num_videos_per_prompt
 
         # 1. Encode Prompts
-        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+        prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask = self.encode_prompt(
             prompt=prompt,
             negative_prompt=negative_prompt,
             num_videos_per_prompt=num_videos_per_prompt,
@@ -662,8 +668,10 @@ def _prepare_model_inputs_i2v(
         prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
         negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
         image_embeds = jax.device_put(image_embeds, data_sharding)
+        text_attention_mask = jax.device_put(text_attention_mask, data_sharding)
+        negative_text_attention_mask = jax.device_put(negative_text_attention_mask, data_sharding)
 
-        return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size
+        return prompt_embeds, negative_prompt_embeds, image_embeds, text_attention_mask, negative_text_attention_mask, effective_batch_size
 
 
     def _prepare_model_inputs(
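
The new mask arrays are placed with the same data_sharding already used for the embeddings. That object is defined elsewhere in the file and not shown in this diff; a typical construction in JAX, included here only as an assumed sketch, shards the leading batch axis across a "data" mesh axis and replicates the rest.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed sketch: one-dimensional device mesh whose single axis is named "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
data_sharding = NamedSharding(mesh, PartitionSpec("data"))

# text_attention_mask has shape (batch, max_sequence_length); only the batch axis is split,
# so the batch size must be divisible by the number of devices on the "data" axis.
# text_attention_mask = jax.device_put(text_attention_mask, data_sharding)
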
@@ -696,7 +704,7 @@ def _prepare_model_inputs(
         batch_size = len(prompt)
 
         with jax.named_scope("Encode-Prompt"):
-            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask = self.encode_prompt(
                 prompt=prompt,
                 negative_prompt=negative_prompt,
                 max_sequence_length=max_sequence_length,
@@ -724,12 +732,14 @@ def _prepare_model_inputs(
         latents = jax.device_put(latents, data_sharding)
         prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
         negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
+        text_attention_mask = jax.device_put(text_attention_mask, data_sharding)
+        negative_text_attention_mask = jax.device_put(negative_text_attention_mask, data_sharding)
 
         scheduler_state = self.scheduler.set_timesteps(
             self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
         )
 
-        return latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames
+        return latents, prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask, scheduler_state, num_frames
 
     @abstractmethod
     def __call__(self, **kwargs):
@@ -747,9 +757,10 @@ def transformer_forward_pass(
     do_classifier_free_guidance,
     guidance_scale,
     encoder_hidden_states_image=None,
+    text_attention_mask=None,
 ):
     wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
-    noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=encoder_hidden_states_image)
+    noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=encoder_hidden_states_image, text_attention_mask=text_attention_mask)
     if do_classifier_free_guidance:
         bsz = latents.shape[0] // 2
         noise_cond = noise_pred[:bsz]  # First half = conditional
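
How the Wan transformer actually consumes text_attention_mask is outside this diff. A common pattern, given here purely as an illustration and not as the model's implementation, is to convert the 0/1 mask into an additive bias on the cross-attention logits so that padded text tokens receive essentially zero attention weight.

import jax.numpy as jnp

def mask_to_attention_bias(text_attention_mask, dtype=jnp.float32):
    # (batch, text_len) 0/1 mask -> (batch, 1, 1, text_len) additive bias that
    # broadcasts over heads and query positions; padded positions get a large
    # negative value so softmax pushes their weights toward zero.
    mask = text_attention_mask[:, None, None, :]
    return jnp.where(mask > 0, 0.0, jnp.finfo(dtype).min).astype(dtype)

# Usage sketch, assuming cross-attention logits of shape (batch, heads, query_len, text_len):
# logits = logits + mask_to_attention_bias(text_attention_mask)
# weights = jax.nn.softmax(logits, axis=-1)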