@@ -698,15 +698,15 @@ def __init__(
698698 self .caption_projection = NNXCombinedTimestepTextProjEmbeddings (
699699 rngs = rngs ,
700700 in_features = self .caption_channels ,
701- hidden_size = inner_dim ,
701+ hidden_size = self . cross_attention_dim ,
702702 embedding_dim = inner_dim ,
703703 dtype = self .dtype ,
704704 weights_dtype = self .weights_dtype ,
705705 )
706706 self .audio_caption_projection = NNXCombinedTimestepTextProjEmbeddings (
707707 rngs = rngs ,
708708 in_features = self .audio_caption_channels ,
709- hidden_size = audio_inner_dim ,
709+ hidden_size = self . audio_cross_attention_dim ,
710710 embedding_dim = audio_inner_dim ,
711711 dtype = self .dtype ,
712712 weights_dtype = self .weights_dtype ,
@@ -1050,10 +1050,10 @@ def __call__(
10501050 audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate .reshape (batch_size , - 1 , audio_cross_attn_v2a_gate .shape [- 1 ])
10511051
10521052 # 4. Prepare prompt embeddings
1053- encoder_hidden_states = self .caption_projection (encoder_hidden_states )
1053+ encoder_hidden_states = self .caption_projection (encoder_hidden_states , timestep )
10541054 encoder_hidden_states = encoder_hidden_states .reshape (batch_size , - 1 , hidden_states .shape [- 1 ])
10551055
1056- audio_encoder_hidden_states = self .audio_caption_projection (audio_encoder_hidden_states )
1056+ audio_encoder_hidden_states = self .audio_caption_projection (audio_encoder_hidden_states , audio_timestep if audio_timestep is not None else timestep )
10571057 audio_encoder_hidden_states = audio_encoder_hidden_states .reshape (batch_size , - 1 , audio_hidden_states .shape [- 1 ])
10581058
10591059 # 5. Run transformer blocks
0 commit comments