Skip to content

Commit edea34d

Browse files
committed
pipeline test fix after text encoder batching
1 parent 4ffd8c7 commit edea34d

1 file changed

Lines changed: 9 additions & 14 deletions

File tree

src/maxdiffusion/tests/ltx2/test_pipeline_ltx2.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -132,29 +132,24 @@ def test_encode_prompt(self, list_embed_mock):
132132
vocoder=MagicMock(),
133133
)
134134

135-
prompt_embeds = jnp.zeros((1, 10, 10))
136-
prompt_attention_mask = jnp.ones((1, 10))
137-
neg_prompt_embeds = jnp.zeros((1, 10, 10))
138-
neg_prompt_attention_mask = jnp.ones((1, 10))
135+
combined_embeds = jnp.zeros((2, 10, 10))
136+
combined_attention_mask = jnp.ones((2, 10))
139137

140-
# Mock return values for positive then negative prompt encoding
141-
list_embed_mock.side_effect = [
142-
(prompt_embeds, prompt_attention_mask),
143-
(neg_prompt_embeds, neg_prompt_attention_mask),
144-
]
138+
# Mock return value for combined prompt encoding
139+
list_embed_mock.return_value = (combined_embeds, combined_attention_mask)
145140

146141
p_e, p_a, n_e, n_a = pipeline.encode_prompt(
147142
prompt=["A cute cat"], negative_prompt=["ugly"], do_classifier_free_guidance=True
148143
)
149144

150145
# Check mock calls
151-
self.assertEqual(list_embed_mock.call_count, 2)
146+
self.assertEqual(list_embed_mock.call_count, 1)
152147

153148
# Check returns
154-
np.testing.assert_array_equal(p_e, prompt_embeds)
155-
np.testing.assert_array_equal(p_a, prompt_attention_mask)
156-
np.testing.assert_array_equal(n_e, neg_prompt_embeds)
157-
np.testing.assert_array_equal(n_a, neg_prompt_attention_mask)
149+
np.testing.assert_array_equal(p_e, combined_embeds[:1])
150+
np.testing.assert_array_equal(p_a, combined_attention_mask[:1])
151+
np.testing.assert_array_equal(n_e, combined_embeds[1:])
152+
np.testing.assert_array_equal(n_a, combined_attention_mask[1:])
158153

159154
@patch("maxdiffusion.pipelines.ltx2.ltx2_pipeline.LTX2Pipeline._get_gemma_prompt_embeds")
160155
def test_encode_prompt_no_cfg(self, list_embed_mock):

0 commit comments

Comments
 (0)