@@ -132,29 +132,24 @@ def test_encode_prompt(self, list_embed_mock):
132132 vocoder = MagicMock (),
133133 )
134134
135- prompt_embeds = jnp .zeros ((1 , 10 , 10 ))
136- prompt_attention_mask = jnp .ones ((1 , 10 ))
137- neg_prompt_embeds = jnp .zeros ((1 , 10 , 10 ))
138- neg_prompt_attention_mask = jnp .ones ((1 , 10 ))
135+ combined_embeds = jnp .zeros ((2 , 10 , 10 ))
136+ combined_attention_mask = jnp .ones ((2 , 10 ))
139137
140- # Mock return values for positive then negative prompt encoding
141- list_embed_mock .side_effect = [
142- (prompt_embeds , prompt_attention_mask ),
143- (neg_prompt_embeds , neg_prompt_attention_mask ),
144- ]
138+ # Mock return value for combined prompt encoding
139+ list_embed_mock .return_value = (combined_embeds , combined_attention_mask )
145140
146141 p_e , p_a , n_e , n_a = pipeline .encode_prompt (
147142 prompt = ["A cute cat" ], negative_prompt = ["ugly" ], do_classifier_free_guidance = True
148143 )
149144
150145 # Check mock calls
151- self .assertEqual (list_embed_mock .call_count , 2 )
146+ self .assertEqual (list_embed_mock .call_count , 1 )
152147
153148 # Check returns
154- np .testing .assert_array_equal (p_e , prompt_embeds )
155- np .testing .assert_array_equal (p_a , prompt_attention_mask )
156- np .testing .assert_array_equal (n_e , neg_prompt_embeds )
157- np .testing .assert_array_equal (n_a , neg_prompt_attention_mask )
149+ np .testing .assert_array_equal (p_e , combined_embeds [: 1 ] )
150+ np .testing .assert_array_equal (p_a , combined_attention_mask [: 1 ] )
151+ np .testing .assert_array_equal (n_e , combined_embeds [ 1 :] )
152+ np .testing .assert_array_equal (n_a , combined_attention_mask [ 1 :] )
158153
159154 @patch ("maxdiffusion.pipelines.ltx2.ltx2_pipeline.LTX2Pipeline._get_gemma_prompt_embeds" )
160155 def test_encode_prompt_no_cfg (self , list_embed_mock ):
0 commit comments