@@ -675,47 +675,6 @@ def load_transformer(
 
 
 
-    @staticmethod
-    def _pack_text_embeds(
-        text_hidden_states: jax.Array,
-        sequence_lengths: jax.Array,
-        padding_side: str = "left",
-        scale_factor: int = 8,
-        eps: float = 1e-6,
-    ) -> jax.Array:
-        """
-        Packs and normalizes text encoder hidden states using JAX natively.
-        """
-        batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
-        original_dtype = text_hidden_states.dtype
-
-        # Create padding mask
-        token_indices = jnp.arange(seq_len)[None, :]
-        if padding_side == "right":
-            mask = token_indices < sequence_lengths[:, None]
-        elif padding_side == "left":
-            start_indices = seq_len - sequence_lengths[:, None]
-            mask = token_indices >= start_indices
-        else:
-            raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
-        mask = mask[:, :, None, None]
-
-        masked_text_hidden_states = jnp.where(mask, text_hidden_states, 0.0)
-        num_valid_positions = (sequence_lengths * hidden_dim).reshape(batch_size, 1, 1, 1)
-        masked_mean = jnp.sum(masked_text_hidden_states, axis=(1, 2), keepdims=True) / (num_valid_positions + eps)
-
-        x_min = jnp.min(jnp.where(mask, text_hidden_states, jnp.inf), axis=(1, 2), keepdims=True)
-        x_max = jnp.max(jnp.where(mask, text_hidden_states, -jnp.inf), axis=(1, 2), keepdims=True)
-
-        normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
-        normalized_hidden_states = normalized_hidden_states * scale_factor
-
-        normalized_hidden_states = normalized_hidden_states.reshape(batch_size, seq_len, -1)
-        mask_flat = jnp.broadcast_to(mask.squeeze(-1), (batch_size, seq_len, hidden_dim * num_layers))
-        normalized_hidden_states = jnp.where(mask_flat, normalized_hidden_states, 0.0)
-        normalized_hidden_states = normalized_hidden_states.astype(original_dtype)
-        return normalized_hidden_states
-
     def _get_gemma_prompt_embeds(
         self,
         prompt: Union[str, List[str]],
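As an aside, the core of the removed `_pack_text_embeds` is the padding-aware mask it builds before normalizing. A minimal sketch of that mask construction on toy shapes (the sequence lengths here are made up for illustration, not taken from the pipeline):

import jax.numpy as jnp

seq_len = 5
sequence_lengths = jnp.array([3, 5])          # valid tokens per prompt, shape (batch,)
token_indices = jnp.arange(seq_len)[None, :]  # shape (1, seq_len)

# padding_side == "left": the first seq_len - length positions are padding,
# so valid tokens sit at the end of each row.
start_indices = seq_len - sequence_lengths[:, None]
mask = token_indices >= start_indices  # True on real tokens
# mask == [[False, False, True, True, True],
#          [True,  True,  True, True, True]]

With `padding_side == "right"` the comparison flips to `token_indices < sequence_lengths[:, None]`, marking the leading positions as valid instead; the downstream masked mean and masked min/max then ignore padded positions entirely.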
@@ -755,33 +714,38 @@ def _get_gemma_prompt_embeds(
                 input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
             )
 
-            text_encoder_hidden_states = torch.stack(text_encoder_outputs.hidden_states, dim=-1)
-            sequence_lengths = prompt_attention_mask.sum(dim=-1)
-
-            # Convert to JAX arrays to do native JAX math
-            hidden_states_jax = jnp.array(text_encoder_hidden_states.cpu().to(torch.float32).numpy())
-            sequence_lengths_jax = jnp.array(sequence_lengths.cpu().numpy())
-            prompt_attention_mask_jax = jnp.array(prompt_attention_mask.cpu().numpy())
-
+            text_encoder_hidden_states = text_encoder_outputs.hidden_states
             del text_encoder_outputs  # Free memory
+
+            prompt_embeds_list = []
+            # Iterate instead of stacking eagerly to avoid 5.7+ GB HBM allocations outside JIT
+            for state in text_encoder_hidden_states:
+                state_np = state.cpu().to(torch.float32).numpy()
+                prompt_embeds_list.append(jnp.array(state_np, dtype=jnp.bfloat16))
+
+            prompt_embeds = prompt_embeds_list
             del text_encoder_hidden_states  # Free PyTorch tensor memory
 
-            prompt_embeds = self._pack_text_embeds(
-                hidden_states_jax,
-                sequence_lengths_jax,
-                padding_side=self.tokenizer.padding_side,
-                scale_factor=scale_factor,
-            )
-            prompt_attention_mask = prompt_attention_mask_jax
+            prompt_attention_mask = jnp.array(prompt_attention_mask.cpu().to(torch.float32).numpy(), dtype=jnp.bool_)
         else:
             raise ValueError("`text_encoder` is required to encode prompts.")
 
         if dtype is not None:
-            prompt_embeds = prompt_embeds.astype(dtype)
-
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = jnp.repeat(prompt_embeds, num_videos_per_prompt, axis=0)
-        prompt_embeds = prompt_embeds.reshape(batch_size * num_videos_per_prompt, seq_len, -1)
+            if isinstance(prompt_embeds, list):
+                prompt_embeds = [state.astype(dtype) for state in prompt_embeds]
+            else:
+                prompt_embeds = prompt_embeds.astype(dtype)
+
+        if isinstance(prompt_embeds, list):
+            _, seq_len, _ = prompt_embeds[0].shape
+            prompt_embeds = [
+                jnp.repeat(state, num_videos_per_prompt, axis=0).reshape(batch_size * num_videos_per_prompt, seq_len, -1)
+                for state in prompt_embeds
+            ]
+        else:
+            _, seq_len, _ = prompt_embeds.shape
+            prompt_embeds = jnp.repeat(prompt_embeds, num_videos_per_prompt, axis=0)
+            prompt_embeds = prompt_embeds.reshape(batch_size * num_videos_per_prompt, seq_len, -1)
 
         prompt_attention_mask = prompt_attention_mask.reshape(batch_size, -1)
         prompt_attention_mask = jnp.repeat(prompt_attention_mask, num_videos_per_prompt, axis=0)
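The `# Iterate instead of stacking eagerly ...` comment above is the crux of the change: `torch.stack(hidden_states, dim=-1)` materializes every encoder layer into a single `(batch, seq, hidden, layers)` float32 buffer before the device copy, whereas the loop transfers one `(batch, seq, hidden)` layer at a time and downcasts to bfloat16 on arrival, so no single allocation ever spans all layers and the resident size is halved. A minimal sketch of the two transfer patterns, using placeholder shapes rather than the real Gemma dimensions:

import torch
import jax.numpy as jnp

# Stand-in for text_encoder_outputs.hidden_states: a tuple of per-layer tensors.
hidden_states = tuple(torch.randn(1, 16, 32) for _ in range(4))

# Old path: one eager float32 stack of all layers, copied to the device at once.
stacked = jnp.array(torch.stack(hidden_states, dim=-1).numpy())

# New path: one layer in flight at a time, stored as bfloat16 on the device.
per_layer = [jnp.array(s.to(torch.float32).numpy(), dtype=jnp.bfloat16) for s in hidden_states]

assert stacked.shape == (1, 16, 32, 4)
assert per_layer[0].dtype == jnp.bfloat16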