Skip to content

Commit 752ce79

Browse files
committed
pipeline test unittest fix
1 parent 4c0ab6a commit 752ce79

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

src/maxdiffusion/tests/ltx2/test_pipeline_ltx2.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,13 @@ def test_check_inputs(self):
118118
with self.assertRaises(ValueError):
119119
pipeline.check_inputs(prompt="test", height=64, width=63)
120120

121+
@patch("maxdiffusion.pipelines.ltx2.ltx2_pipeline.get_tpu_type")
121122
@patch("maxdiffusion.pipelines.ltx2.ltx2_pipeline.LTX2Pipeline._get_gemma_prompt_embeds")
122-
def test_encode_prompt(self, list_embed_mock):
123+
def test_encode_prompt(self, list_embed_mock, mock_get_tpu_type):
123124
"""Test conditional encoding of positive and negative prompts."""
125+
from maxdiffusion.tpu_utils import TpuType
126+
mock_get_tpu_type.return_value = TpuType.TPU_7X
127+
124128
pipeline = LTX2Pipeline(
125129
scheduler=MagicMock(),
126130
vae=MagicMock(),

0 commit comments

Comments
 (0)