@@ -365,6 +365,8 @@ def __call__(
365365 audio_encoder_attention_mask : Optional [jax .Array ] = None ,
366366 a2v_cross_attention_mask : Optional [jax .Array ] = None ,
367367 v2a_cross_attention_mask : Optional [jax .Array ] = None ,
368+ use_a2v_cross_attention : bool = True ,
369+ use_v2a_cross_attention : bool = True ,
368370 perturbation_mask : Optional [jax .Array ] = None ,
369371 ) -> Tuple [jax .Array , jax .Array ]:
370372 batch_size = hidden_states .shape [0 ]
@@ -497,33 +499,35 @@ def __call__(
497499 mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale ) + video_a2v_ca_shift
498500 mod_norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_a2v_ca_scale ) + audio_a2v_ca_shift
499501
500- with jax .named_scope ("Audio-to-Video Cross-Attention" ):
501- a2v_attn_hidden_states = self .audio_to_video_attn (
502- mod_norm_hidden_states ,
503- encoder_hidden_states = mod_norm_audio_hidden_states ,
504- rotary_emb = ca_video_rotary_emb ,
505- k_rotary_emb = ca_audio_rotary_emb ,
506- attention_mask = a2v_cross_attention_mask ,
507- )
508-
509- if perturbation_mask is not None :
510- a2v_attn_hidden_states = mod_norm_audio_hidden_states + perturbation_mask * (a2v_attn_hidden_states - mod_norm_audio_hidden_states )
502+ if use_a2v_cross_attention :
503+ with jax .named_scope ("Audio-to-Video Cross-Attention" ):
504+ a2v_attn_hidden_states = self .audio_to_video_attn (
505+ mod_norm_hidden_states ,
506+ encoder_hidden_states = mod_norm_audio_hidden_states ,
507+ rotary_emb = ca_video_rotary_emb ,
508+ k_rotary_emb = ca_audio_rotary_emb ,
509+ attention_mask = a2v_cross_attention_mask ,
510+ )
511511
512- hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
512+ if perturbation_mask is not None :
513+ a2v_attn_hidden_states = mod_norm_audio_hidden_states + perturbation_mask * (a2v_attn_hidden_states - mod_norm_audio_hidden_states )
514+
515+ hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
513516
514517 # Video-to-Audio Cross Attention: Q: Audio; K,V: Video
515518 mod_norm_hidden_states_v2a = norm_hidden_states * (1 + video_v2a_ca_scale ) + video_v2a_ca_shift
516519 mod_norm_audio_hidden_states_v2a = norm_audio_hidden_states * (1 + audio_v2a_ca_scale ) + audio_v2a_ca_shift
517520
518- with jax .named_scope ("Video-to-Audio Cross-Attention" ):
519- v2a_attn_hidden_states = self .video_to_audio_attn (
520- mod_norm_audio_hidden_states_v2a ,
521- encoder_hidden_states = mod_norm_hidden_states_v2a ,
522- rotary_emb = ca_audio_rotary_emb ,
523- k_rotary_emb = ca_video_rotary_emb ,
524- attention_mask = v2a_cross_attention_mask ,
525- )
526- audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
521+ if use_v2a_cross_attention :
522+ with jax .named_scope ("Video-to-Audio Cross-Attention" ):
523+ v2a_attn_hidden_states = self .video_to_audio_attn (
524+ mod_norm_audio_hidden_states_v2a ,
525+ encoder_hidden_states = mod_norm_hidden_states_v2a ,
526+ rotary_emb = ca_audio_rotary_emb ,
527+ k_rotary_emb = ca_video_rotary_emb ,
528+ attention_mask = v2a_cross_attention_mask ,
529+ )
530+ audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
527531
528532 if modality_mask is not None :
529533 hidden_states = hidden_states * modality_mask
@@ -966,6 +970,7 @@ def __call__(
966970 attention_kwargs : Optional [Dict [str , Any ]] = None ,
967971 use_cross_timestep : bool = False ,
968972 modality_mask : Optional [jax .Array ] = None ,
973+ isolate_modalities : bool = False ,
969974 return_dict : bool = True ,
970975 perturbation_mask : Optional [jax .Array ] = None ,
971976 ) -> Any :
@@ -1111,6 +1116,8 @@ def scan_fn(carry, block_and_mask):
11111116 ca_audio_rotary_emb = audio_cross_attn_rotary_emb ,
11121117 a2v_cross_attention_mask = encoder_attention_mask ,
11131118 v2a_cross_attention_mask = audio_encoder_attention_mask ,
1119+ use_a2v_cross_attention = not isolate_modalities ,
1120+ use_v2a_cross_attention = not isolate_modalities ,
11141121 perturbation_mask = mask ,
11151122 modality_mask = modality_mask ,
11161123 )
@@ -1157,6 +1164,8 @@ def scan_fn(carry, block_and_mask):
11571164 audio_encoder_attention_mask = audio_encoder_attention_mask ,
11581165 a2v_cross_attention_mask = encoder_attention_mask ,
11591166 v2a_cross_attention_mask = audio_encoder_attention_mask ,
1167+ use_a2v_cross_attention = not isolate_modalities ,
1168+ use_v2a_cross_attention = not isolate_modalities ,
11601169 perturbation_mask = mask ,
11611170 )
11621171
0 commit comments