@@ -677,20 +677,20 @@ def load_transformer(
 
   @staticmethod
   def _pack_text_embeds(
-      text_hidden_states: torch.Tensor,
-      sequence_lengths: torch.Tensor,
+      text_hidden_states: jax.Array,
+      sequence_lengths: jax.Array,
       padding_side: str = "left",
       scale_factor: int = 8,
       eps: float = 1e-6,
-  ) -> torch.Tensor:
+  ) -> jax.Array:
     """
-    Packs and normalizes text encoder hidden states using PyTorch to save device HBM.
+    Packs and normalizes text encoder hidden states natively in JAX.
     """
     batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
     original_dtype = text_hidden_states.dtype
 
     # Create padding mask
-    token_indices = torch.arange(seq_len, device=text_hidden_states.device).unsqueeze(0)
+    token_indices = jnp.arange(seq_len)[None, :]
     if padding_side == "right":
       mask = token_indices < sequence_lengths[:, None]
     elif padding_side == "left":
@@ -700,20 +700,20 @@ def _pack_text_embeds(
       raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
     mask = mask[:, :, None, None]
 
-    masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
-    num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
-    masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
+    masked_text_hidden_states = jnp.where(mask, text_hidden_states, 0.0)
+    num_valid_positions = (sequence_lengths * hidden_dim).reshape(batch_size, 1, 1, 1)
+    masked_mean = jnp.sum(masked_text_hidden_states, axis=(1, 2), keepdims=True) / (num_valid_positions + eps)
 
-    x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
-    x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
+    x_min = jnp.min(jnp.where(mask, text_hidden_states, jnp.inf), axis=(1, 2), keepdims=True)
+    x_max = jnp.max(jnp.where(mask, text_hidden_states, -jnp.inf), axis=(1, 2), keepdims=True)
 
     normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
     normalized_hidden_states = normalized_hidden_states * scale_factor
 
-    normalized_hidden_states = normalized_hidden_states.flatten(2)
-    mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
-    normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
-    normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
+    normalized_hidden_states = normalized_hidden_states.reshape(batch_size, seq_len, -1)
+    mask_flat = jnp.broadcast_to(mask.squeeze(-1), (batch_size, seq_len, hidden_dim * num_layers))
+    normalized_hidden_states = jnp.where(mask_flat, normalized_hidden_states, 0.0)
+    normalized_hidden_states = normalized_hidden_states.astype(original_dtype)
     return normalized_hidden_states
 
   def _get_gemma_prompt_embeds(
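Aside: a minimal standalone sketch (not part of this commit) of the masking idiom the JAX version relies on. `jnp.where(mask, x, fill)` plays the role of PyTorch's `x.masked_fill(~mask, fill)`, and min/max reductions over `inf`-filled copies replace `amin`/`amax`. Shapes and values here are illustrative only, using the right-padding branch shown above.

```python
import jax.numpy as jnp

# Toy inputs: (batch, seq, hidden, layers), right padding.
batch_size, seq_len, hidden_dim, num_layers = 2, 5, 3, 2
text_hidden_states = jnp.ones((batch_size, seq_len, hidden_dim, num_layers))
sequence_lengths = jnp.array([3, 5])

token_indices = jnp.arange(seq_len)[None, :]                  # (1, seq_len)
mask = (token_indices < sequence_lengths[:, None])[:, :, None, None]

# jnp.where(mask, x, fill) stands in for x.masked_fill(~mask, fill).
masked = jnp.where(mask, text_hidden_states, 0.0)
x_min = jnp.min(jnp.where(mask, text_hidden_states, jnp.inf), axis=(1, 2), keepdims=True)
x_max = jnp.max(jnp.where(mask, text_hidden_states, -jnp.inf), axis=(1, 2), keepdims=True)

print(masked.shape, x_min.shape, x_max.shape)  # (2, 5, 3, 2) (2, 1, 1, 2) (2, 1, 1, 2)
```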
@@ -733,7 +733,6 @@ def _get_gemma_prompt_embeds(
     self.tokenizer.pad_token = self.tokenizer.eos_token
 
     prompt = [p.strip() for p in prompt]
-    # Return Numpy tensors to be compatible with JAX if no text encoder, else PyTorch
 
     if self.text_encoder is not None:
       # PyTorch Text Encoder
@@ -748,49 +747,41 @@ def _get_gemma_prompt_embeds(
       text_input_ids = text_inputs.input_ids
       prompt_attention_mask = text_inputs.attention_mask
 
-      # Move to device if needed (assuming text_encoder is on correct device or CPU)
-      # For now, keep on CPU or same device as model
       text_input_ids = text_input_ids.to(self.text_encoder.device)
       prompt_attention_mask = prompt_attention_mask.to(self.text_encoder.device)
 
-      max_logging.log(f"DEBUG: text_encoder is on {self.text_encoder.device}")
-      max_logging.log(f"DEBUG: text_input_ids is on {text_input_ids.device}")
-
       with torch.no_grad():
         text_encoder_outputs = self.text_encoder(
             input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
         )
 
-      text_encoder_hidden_states = text_encoder_outputs.hidden_states
-      del text_encoder_outputs  # Free memory
+      text_encoder_hidden_states = torch.stack(text_encoder_outputs.hidden_states, dim=-1)
+      sequence_lengths = prompt_attention_mask.sum(dim=-1)
 
-      prompt_embeds_list = []
-      for state in text_encoder_hidden_states:
-        state_np = state.cpu().to(torch.float32).numpy()
-        prompt_embeds_list.append(jnp.array(state_np, dtype=jnp.bfloat16))
-
-      prompt_embeds = prompt_embeds_list
+      # Convert to JAX arrays to do native JAX math
+      hidden_states_jax = jnp.array(text_encoder_hidden_states.cpu().to(torch.float32).numpy())
+      sequence_lengths_jax = jnp.array(sequence_lengths.cpu().numpy())
+      prompt_attention_mask_jax = jnp.array(prompt_attention_mask.cpu().numpy())
+
+      del text_encoder_outputs  # Free memory
       del text_encoder_hidden_states  # Free PyTorch tensor memory
 
-      prompt_attention_mask = jnp.array(prompt_attention_mask.cpu().to(torch.float32).numpy(), dtype=jnp.bfloat16)
+      prompt_embeds = self._pack_text_embeds(
+          hidden_states_jax,
+          sequence_lengths_jax,
+          padding_side=self.tokenizer.padding_side,
+          scale_factor=scale_factor,
+      )
+      prompt_attention_mask = prompt_attention_mask_jax
     else:
       raise ValueError("`text_encoder` is required to encode prompts.")
+
     if dtype is not None:
-      if isinstance(prompt_embeds, list):
-        prompt_embeds = [state.astype(dtype) for state in prompt_embeds]
-      else:
-        prompt_embeds = prompt_embeds.astype(dtype)
-
-    if isinstance(prompt_embeds, list):
-      _, seq_len, _ = prompt_embeds[0].shape
-      prompt_embeds = [
-          jnp.repeat(state, num_videos_per_prompt, axis=0).reshape(batch_size * num_videos_per_prompt, seq_len, -1)
-          for state in prompt_embeds
-      ]
-    else:
-      _, seq_len, _ = prompt_embeds.shape
-      prompt_embeds = jnp.repeat(prompt_embeds, num_videos_per_prompt, axis=0)
-      prompt_embeds = prompt_embeds.reshape(batch_size * num_videos_per_prompt, seq_len, -1)
+      prompt_embeds = prompt_embeds.astype(dtype)
+
+    _, seq_len, _ = prompt_embeds.shape
+    prompt_embeds = jnp.repeat(prompt_embeds, num_videos_per_prompt, axis=0)
+    prompt_embeds = prompt_embeds.reshape(batch_size * num_videos_per_prompt, seq_len, -1)
 
     prompt_attention_mask = prompt_attention_mask.reshape(batch_size, -1)
     prompt_attention_mask = jnp.repeat(prompt_attention_mask, num_videos_per_prompt, axis=0)
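For reference, a rough sketch of the PyTorch-to-JAX handoff this commit introduces: stack the per-layer hidden states on a trailing axis, move a float32 copy to host, and wrap it as a JAX array. The tensors and shapes below are placeholders, not the pipeline's real encoder outputs.

```python
import jax.numpy as jnp
import torch

# Placeholder encoder outputs: 4 hidden-state layers of shape (batch, seq, hidden).
hidden_states = tuple(torch.randn(2, 16, 8) for _ in range(4))
attention_mask = torch.ones(2, 16, dtype=torch.long)

# Stack layers on a trailing axis, then hand a float32 CPU copy to JAX.
stacked = torch.stack(hidden_states, dim=-1)                         # (2, 16, 8, 4)
hidden_states_jax = jnp.array(stacked.cpu().to(torch.float32).numpy())
sequence_lengths_jax = jnp.array(attention_mask.sum(dim=-1).cpu().numpy())

print(hidden_states_jax.shape, sequence_lengths_jax)                 # (2, 16, 8, 4) [16 16]
```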