@@ -440,12 +440,24 @@ def _get_t5_prompt_embeds(
             [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
         )

+        # Create attention mask: 1 for real tokens, 0 for padded tokens
+        text_attention_mask = torch.zeros((batch_size, max_sequence_length), dtype=torch.long)
+        for i, seq_len_i in enumerate(seq_lens):
+            text_attention_mask[i, :seq_len_i] = 1
+
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+        # Duplicate attention mask for each generation per prompt
+        text_attention_mask = text_attention_mask.repeat(1, num_videos_per_prompt)
+        text_attention_mask = text_attention_mask.view(batch_size * num_videos_per_prompt, max_sequence_length)
+
+        # Convert to JAX array
+        text_attention_mask = jnp.array(text_attention_mask.numpy())

-        return prompt_embeds
+        return prompt_embeds, text_attention_mask

     def encode_prompt(
         self,
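Note: the per-row Python loop that fills text_attention_mask is clear but can be replaced by a single vectorized comparison. A minimal sketch, assuming seq_lens is a 1-D torch tensor of unpadded token counts per prompt (which is what the loop implies):

    positions = torch.arange(max_sequence_length)  # shape (L,)
    # 1 where the position falls inside the row's true length, 0 in the padding
    text_attention_mask = (positions[None, :] < seq_lens[:, None]).long()  # shape (B, L)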
@@ -459,24 +471,28 @@ def encode_prompt(
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
         if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
+            prompt_embeds, text_attention_mask = self._get_t5_prompt_embeds(
                 prompt=prompt,
                 num_videos_per_prompt=num_videos_per_prompt,
                 max_sequence_length=max_sequence_length,
             )
             prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32)
+        else:
+            text_attention_mask = None

         if negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            negative_prompt_embeds = self._get_t5_prompt_embeds(
+            negative_prompt_embeds, negative_text_attention_mask = self._get_t5_prompt_embeds(
                 prompt=negative_prompt,
                 num_videos_per_prompt=num_videos_per_prompt,
                 max_sequence_length=max_sequence_length,
             )
             negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32)
+        else:
+            negative_text_attention_mask = None

-        return prompt_embeds, negative_prompt_embeds
+        return prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask

     def prepare_latents(
         self,
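With this change, encode_prompt returns four values, and when precomputed embeddings are supplied the matching mask comes back as None, so downstream code must tolerate that. A hypothetical call site (pipe, the prompt strings, and the 512-token limit are illustrative, not taken from this diff):

    prompt_embeds, negative_prompt_embeds, text_mask, neg_text_mask = pipe.encode_prompt(
        prompt="a corgi surfing at sunset",
        negative_prompt="",
        num_videos_per_prompt=1,
        max_sequence_length=512,
    )
    # Each mask is either None (embeds were precomputed) or a (batch, 512) jnp array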
@@ -687,7 +703,7 @@ def _prepare_model_inputs(
         batch_size = len(prompt)

         with jax.named_scope("Encode-Prompt"):
-            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask = self.encode_prompt(
                 prompt=prompt,
                 negative_prompt=negative_prompt,
                 max_sequence_length=max_sequence_length,
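jax.named_scope only labels the enclosed operations for profiles and HLO dumps; it does not change the computation. A self-contained sketch of the same pattern:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def encode(x):
        with jax.named_scope("Encode-Prompt"):  # name appears in traces/profiler output
            return jnp.tanh(x)

    encode(jnp.ones((4,)))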
@@ -715,12 +731,16 @@ def _prepare_model_inputs(
         latents = jax.device_put(latents, data_sharding)
         prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
         negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
+        if text_attention_mask is not None:
+            text_attention_mask = jax.device_put(text_attention_mask, data_sharding)
+        if negative_text_attention_mask is not None:
+            negative_text_attention_mask = jax.device_put(negative_text_attention_mask, data_sharding)

         scheduler_state = self.scheduler.set_timesteps(
             self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
         )

-        return latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames
+        return latents, prompt_embeds, negative_prompt_embeds, text_attention_mask, negative_text_attention_mask, scheduler_state, num_frames

     @abstractmethod
     def __call__(self, **kwargs):
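The new masks are placed with the same data_sharding as the embeddings, so every array entering the denoising loop shares one layout. data_sharding is defined elsewhere in this pipeline; a minimal sketch of one plausible definition, assuming a 1-D mesh over all local devices with the batch axis sharded:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
    data_sharding = NamedSharding(mesh, P("data"))  # shard axis 0; batch size must be divisible by the device count

    mask = jnp.ones((8, 512), dtype=jnp.int32)
    mask = jax.device_put(mask, data_sharding)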
@@ -738,9 +758,16 @@ def transformer_forward_pass(
     do_classifier_free_guidance,
     guidance_scale,
     encoder_hidden_states_image=None,
+    encoder_attention_mask=None,
 ):
     wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
-    noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=encoder_hidden_states_image)
+    noise_pred = wan_transformer(
+        hidden_states=latents,
+        timestep=timestep,
+        encoder_hidden_states=prompt_embeds,
+        encoder_hidden_states_image=encoder_hidden_states_image,
+        encoder_attention_mask=encoder_attention_mask,
+    )
     if do_classifier_free_guidance:
         bsz = latents.shape[0] // 2
         noise_cond = noise_pred[:bsz]  # First half = conditional
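The hunk cuts off at the conditional half; the standard classifier-free guidance combination presumably follows. A sketch using the names already in scope here:

    noise_uncond = noise_pred[bsz:]  # second half = unconditional
    # Push the prediction away from the unconditional branch, scaled by guidance_scale
    noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)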