1616
1717import os
1818import unittest
19+ import pytest
20+ import jax .numpy as jnp
1921from absl .testing import absltest
2022
2123from transformers import CLIPTokenizer , FlaxCLIPTextModel
22- from transformers import T5TokenizerFast , T5EncoderModel
24+ from transformers import T5TokenizerFast , FlaxT5EncoderModel
2325
2426from ..generate_flux import get_clip_prompt_embeds , get_t5_prompt_embeds
2527
28+ IN_GITHUB_ACTIONS = os .getenv ("GITHUB_ACTIONS" ) == "true"
2629THIS_DIR = os .path .dirname (os .path .abspath (__file__ ))
2730
2831
@@ -32,22 +35,18 @@ class TextEncoderTest(unittest.TestCase):
3235 def setUp (self ):
3336 TextEncoderTest .dummy_data = {}
3437
38+ @pytest .mark .skipif (IN_GITHUB_ACTIONS , reason = "Don't run smoke tests on Github Actions" )
3539 def test_flux_t5_text_encoder (self ):
3640
37- text_encoder_2_pt = T5EncoderModel .from_pretrained (
38- "black-forest-labs/FLUX.1-dev" ,
39- subfolder = "text_encoder_2" ,
40- )
41+ text_encoder = FlaxT5EncoderModel .from_pretrained ("ariG23498/t5-v1-1-xxl-flax" )
4142
42- tokenizer_2 = T5TokenizerFast .from_pretrained (
43- "black-forest-labs/FLUX.1-dev" ,
44- subfolder = "tokenizer_2" ,
45- )
43+ tokenizer_2 = T5TokenizerFast .from_pretrained ("ariG23498/t5-v1-1-xxl-flax" )
4644
47- embeds = get_t5_prompt_embeds ("A dog on a skateboard" , 2 , tokenizer_2 , text_encoder_2_pt )
45+ embeds = get_t5_prompt_embeds ("A dog on a skateboard" , 2 , tokenizer_2 , text_encoder )
4846
4947 assert embeds .shape == (2 , 512 , 4096 )
5048
49+ @pytest .mark .skipif (IN_GITHUB_ACTIONS , reason = "Don't run smoke tests on Github Actions" )
5150 def test_flux_clip_text_encoder (self ):
5251
5352 text_encoder = FlaxCLIPTextModel .from_pretrained (
@@ -56,3 +55,7 @@ def test_flux_clip_text_encoder(self):
5655 tokenizer = CLIPTokenizer .from_pretrained ("black-forest-labs/FLUX.1-dev" , subfolder = "tokenizer" , dtype = "bfloat16" )
5756 embeds = get_clip_prompt_embeds ("A cat riding a skateboard" , 2 , tokenizer , text_encoder )
5857 assert embeds .shape == (2 , 768 )
58+
59+
60+ if __name__ == "__main__" :
61+ absltest .main ()
0 commit comments