@@ -473,7 +473,11 @@ def _get_t5_prompt_embeds(
         prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

-        return prompt_embeds
+        mask = mask.repeat(1, num_videos_per_prompt)
+        mask = mask.view(batch_size * num_videos_per_prompt, seq_len)
+        mask = jnp.array(mask.detach().numpy(), dtype=jnp.int32)
+
+        return prompt_embeds, mask

     def encode_prompt(
         self,
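A quick sketch of the torch-to-JAX mask handoff this hunk adds, with assumed shapes (mask is the [batch_size, seq_len] attention mask from the T5 tokenizer; the sizes below are illustrative):

import torch
import jax.numpy as jnp

batch_size, seq_len, num_videos_per_prompt = 2, 226, 3
mask = torch.ones(batch_size, seq_len, dtype=torch.long)       # tokenizer attention mask
mask = mask.repeat(1, num_videos_per_prompt)                   # [2, 678]: tile along the seq dim
mask = mask.view(batch_size * num_videos_per_prompt, seq_len)  # [6, 226]: one row per generated video
mask = jnp.array(mask.detach().numpy(), dtype=jnp.int32)       # cross the torch/JAX boundary as int32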
@@ -483,28 +487,36 @@ def encode_prompt(
         max_sequence_length: int = 226,
         prompt_embeds: jax.Array = None,
         negative_prompt_embeds: jax.Array = None,
+        prompt_mask: jax.Array = None,
+        negative_prompt_mask: jax.Array = None,
     ):
         prompt = [prompt] if isinstance(prompt, str) else prompt
         if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
+            prompt_embeds, prompt_mask = self._get_t5_prompt_embeds(
                 prompt=prompt,
                 num_videos_per_prompt=num_videos_per_prompt,
                 max_sequence_length=max_sequence_length,
             )
             prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32)
+        else:
+            if prompt_mask is None:
+                prompt_mask = jnp.ones((prompt_embeds.shape[0], prompt_embeds.shape[1]), dtype=jnp.int32)

         if negative_prompt_embeds is None:
             batch_size = len(prompt_embeds)
             negative_prompt = negative_prompt or ""
             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            negative_prompt_embeds = self._get_t5_prompt_embeds(
+            negative_prompt_embeds, negative_prompt_mask = self._get_t5_prompt_embeds(
                 prompt=negative_prompt,
                 num_videos_per_prompt=num_videos_per_prompt,
                 max_sequence_length=max_sequence_length,
             )
             negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32)
+        else:
+            if negative_prompt_mask is None:
+                negative_prompt_mask = jnp.ones((negative_prompt_embeds.shape[0], negative_prompt_embeds.shape[1]), dtype=jnp.int32)

-        return prompt_embeds, negative_prompt_embeds
+        return prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask

     def prepare_latents(
         self,
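A hypothetical call site for the widened signature; when embeds are passed in without masks, all-ones masks are synthesized, so existing callers keep working (pipe stands in for a constructed pipeline, and the argument values are illustrative):

prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask = pipe.encode_prompt(
    prompt="a red panda climbing a snowy tree",
    negative_prompt="",
    num_videos_per_prompt=1,
    max_sequence_length=226,
)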
@@ -647,7 +659,7 @@ def _prepare_model_inputs_i2v(
         effective_batch_size = batch_size * num_videos_per_prompt

         # 1. Encode Prompts
-        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+        prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask = self.encode_prompt(
             prompt=prompt,
             negative_prompt=negative_prompt,
             num_videos_per_prompt=num_videos_per_prompt,
@@ -691,8 +703,10 @@ def _prepare_model_inputs_i2v(
         prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
         negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
         image_embeds = jax.device_put(image_embeds, data_sharding)
+        prompt_mask = jax.device_put(prompt_mask, data_sharding)
+        negative_prompt_mask = jax.device_put(negative_prompt_mask, data_sharding)

-        return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size
+        return prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size, prompt_mask, negative_prompt_mask

     def _prepare_model_inputs(
         self,
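For reference, a minimal sketch of a data_sharding that would satisfy these device_put calls, assuming a 1-D "data" mesh over all local devices (the pipeline's real mesh setup lives elsewhere):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
data_sharding = NamedSharding(mesh, P("data"))        # shard dim 0 (batch) across devices
prompt_mask = jnp.ones((4, 226), dtype=jnp.int32)     # stand-in mask for the sketch
prompt_mask = jax.device_put(prompt_mask, data_sharding)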
@@ -724,7 +738,7 @@ def _prepare_model_inputs(
         batch_size = len(prompt)

         with jax.named_scope("Encode-Prompt"):
-            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt_embeds, prompt_mask, negative_prompt_embeds, negative_prompt_mask = self.encode_prompt(
                 prompt=prompt,
                 negative_prompt=negative_prompt,
                 max_sequence_length=max_sequence_length,
@@ -752,12 +766,14 @@ def _prepare_model_inputs(
         latents = jax.device_put(latents, data_sharding)
         prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
         negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
+        prompt_mask = jax.device_put(prompt_mask, data_sharding)
+        negative_prompt_mask = jax.device_put(negative_prompt_mask, data_sharding)

         scheduler_state = self.scheduler.set_timesteps(
             self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
         )

-        return latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames
+        return latents, prompt_embeds, negative_prompt_embeds, prompt_mask, negative_prompt_mask, scheduler_state, num_frames

     @abstractmethod
     def __call__(self, **kwargs):
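Every caller of _prepare_model_inputs now unpacks seven values instead of five, with the masks slotted between the embeds and the scheduler state. A call-site sketch (inputs_kwargs is a stand-in for the real arguments, not a name from this PR):

(latents, prompt_embeds, negative_prompt_embeds,
 prompt_mask, negative_prompt_mask,
 scheduler_state, num_frames) = self._prepare_model_inputs(**inputs_kwargs)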
@@ -782,6 +798,7 @@ def transformer_forward_pass(
     kv_cache=None,
     rotary_emb=None,
     encoder_attention_mask=None,
+    text_mask=None,
 ):
     wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
     outputs = wan_transformer(
@@ -795,6 +812,7 @@ def transformer_forward_pass(
         kv_cache=kv_cache,
         rotary_emb=rotary_emb,
         encoder_attention_mask=encoder_attention_mask,
+        text_mask=text_mask,
     )

     if return_residual:
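One likely way the cond/uncond masks reach text_mask in the batched CFG passes below, mirroring how the embeds are stacked so that noise_pred[:bsz] and noise_pred[bsz:] stay row-aligned (an assumption, not shown in this diff):

encoder_hidden_states = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
text_mask = jnp.concatenate([prompt_mask, negative_prompt_mask], axis=0)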
@@ -828,6 +846,7 @@ def transformer_forward_pass_full_cfg(
     kv_cache=None,
     rotary_emb=None,
     encoder_attention_mask=None,
+    text_mask=None,
 ):
     """Full CFG forward pass.

@@ -849,6 +868,7 @@ def transformer_forward_pass_full_cfg(
         kv_cache=kv_cache,
         rotary_emb=rotary_emb,
         encoder_attention_mask=encoder_attention_mask,
+        text_mask=text_mask,
     )
     noise_cond = noise_pred[:bsz]
     noise_uncond = noise_pred[bsz:]
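Downstream of this split, the standard classifier-free guidance combine applies (guidance_scale assumed in scope):

noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)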
@@ -873,6 +893,7 @@ def transformer_forward_pass_cfg_cache(
     kv_cache=None,
     rotary_emb=None,
     encoder_attention_mask=None,
+    text_mask=None,
 ):
     """CFG-Cache forward pass with FFT frequency-domain compensation.

@@ -901,6 +922,7 @@ def transformer_forward_pass_cfg_cache(
         kv_cache=kv_cache,
         rotary_emb=rotary_emb,
         encoder_attention_mask=encoder_attention_mask,
+        text_mask=text_mask,
     )

     # FFT over spatial dims (H, W) — last 2 dims of [B, C, F, H, W]
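A minimal sketch of the spatial FFT this comment describes, assuming a [B, C, F, H, W] tensor; the pipeline's actual low/high-frequency compensation blend is not reproduced here:

import jax.numpy as jnp

def spatial_fft(x):
    # FFT over the last two (H, W) dims, DC-centered so a radial
    # low/high-frequency mask can be applied per frame
    return jnp.fft.fftshift(jnp.fft.fftn(x, axes=(-2, -1)), axes=(-2, -1))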