@@ -81,7 +81,7 @@ def __init__(
8181 in_features = input_dim , out_features = a_dim , use_bias = proj_bias , rngs = rngs
8282 )
8383
84- self .video_connector = Embeddings1DConnector (
84+ self .video_embeddings_connector = Embeddings1DConnector (
8585 input_dim = v_dim ,
8686 heads = video_connector_num_attention_heads ,
8787 head_dim = video_connector_attention_head_dim ,
@@ -96,7 +96,7 @@ def __init__(
9696 rngs = rngs ,
9797 gated_attn = video_gated_attn ,
9898 )
99- self .audio_connector = Embeddings1DConnector (
99+ self .audio_embeddings_connector = Embeddings1DConnector (
100100 input_dim = a_dim ,
101101 heads = audio_connector_num_attention_heads ,
102102 head_dim = audio_connector_attention_head_dim ,
@@ -193,16 +193,16 @@ def __call__(
193193 # Using self.caption_channels if available, or fallback to config or 3840
194194 cap_channels = getattr (self , "caption_channels" , getattr (self .config , "caption_channels" , 3840 ))
195195
196- video_scale_factor = jnp .sqrt (self .video_connector .dim / cap_channels )
196+ video_scale_factor = jnp .sqrt (self .video_embeddings_connector .dim / cap_channels )
197197 video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor
198- audio_scale_factor = jnp .sqrt (self .audio_connector .dim / cap_channels )
198+ audio_scale_factor = jnp .sqrt (self .audio_embeddings_connector .dim / cap_channels )
199199 audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor
200200
201201 video_text_emb_proj = self .video_text_proj_in (video_norm_text_emb )
202202 audio_text_emb_proj = self .audio_text_proj_in (audio_norm_text_emb )
203203
204- video_embeds , new_attention_mask = self .video_connector (video_text_emb_proj , attention_mask )
205- audio_embeds , _ = self .audio_connector (audio_text_emb_proj , attention_mask )
204+ video_embeds , new_attention_mask = self .video_embeddings_connector (video_text_emb_proj , attention_mask )
205+ audio_embeds , _ = self .audio_embeddings_connector (audio_text_emb_proj , attention_mask )
206206 else :
207207 # 1. Shared Feature Extraction
208208 features = self .feature_extractor (hidden_states , attention_mask )
0 commit comments