Skip to content

Commit e56825f

Browse files
fix rest of unit tests.
1 parent 37df8b9 commit e56825f

3 files changed

Lines changed: 21 additions & 14 deletions

File tree

src/maxdiffusion/max_utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,7 @@
4646
from flax.linen import partitioning as nn_partitioning
4747
from flax.training import train_state
4848
from jax.experimental import mesh_utils
49-
from transformers import (
50-
FlaxCLIPTextModel,
51-
FlaxCLIPTextPreTrainedModel
52-
)
49+
from transformers import (FlaxCLIPTextModel, FlaxCLIPTextPreTrainedModel)
5350
from flax import struct
5451
from typing import (
5552
Callable,

src/maxdiffusion/tests/text_encoders_test.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@
1616

1717
import os
1818
import unittest
19+
import pytest
20+
import jax.numpy as jnp
1921
from absl.testing import absltest
2022

2123
from transformers import CLIPTokenizer, FlaxCLIPTextModel
22-
from transformers import T5TokenizerFast, T5EncoderModel
24+
from transformers import T5TokenizerFast, FlaxT5EncoderModel
2325

2426
from ..generate_flux import get_clip_prompt_embeds, get_t5_prompt_embeds
2527

28+
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
2629
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
2730

2831

@@ -32,22 +35,18 @@ class TextEncoderTest(unittest.TestCase):
3235
def setUp(self):
3336
TextEncoderTest.dummy_data = {}
3437

38+
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
3539
def test_flux_t5_text_encoder(self):
3640

37-
text_encoder_2_pt = T5EncoderModel.from_pretrained(
38-
"black-forest-labs/FLUX.1-dev",
39-
subfolder="text_encoder_2",
40-
)
41+
text_encoder = FlaxT5EncoderModel.from_pretrained("ariG23498/t5-v1-1-xxl-flax")
4142

42-
tokenizer_2 = T5TokenizerFast.from_pretrained(
43-
"black-forest-labs/FLUX.1-dev",
44-
subfolder="tokenizer_2",
45-
)
43+
tokenizer_2 = T5TokenizerFast.from_pretrained("ariG23498/t5-v1-1-xxl-flax")
4644

47-
embeds = get_t5_prompt_embeds("A dog on a skateboard", 2, tokenizer_2, text_encoder_2_pt)
45+
embeds = get_t5_prompt_embeds("A dog on a skateboard", 2, tokenizer_2, text_encoder)
4846

4947
assert embeds.shape == (2, 512, 4096)
5048

49+
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
5150
def test_flux_clip_text_encoder(self):
5251

5352
text_encoder = FlaxCLIPTextModel.from_pretrained(
@@ -56,3 +55,7 @@ def test_flux_clip_text_encoder(self):
5655
tokenizer = CLIPTokenizer.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="tokenizer", dtype="bfloat16")
5756
embeds = get_clip_prompt_embeds("A cat riding a skateboard", 2, tokenizer, text_encoder)
5857
assert embeds.shape == (2, 768)
58+
59+
60+
if __name__ == "__main__":
61+
absltest.main()

src/maxdiffusion/tests/vae_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import os
1818
import unittest
19+
import pytest
1920
from absl.testing import absltest
2021

2122
import numpy as np
@@ -27,6 +28,7 @@
2728
from skimage.metrics import structural_similarity as ssim
2829

2930
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
31+
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
3032

3133

3234
class VaeTest(unittest.TestCase):
@@ -35,6 +37,7 @@ class VaeTest(unittest.TestCase):
3537
def setUp(self):
3638
VaeTest.dummy_data = {}
3739

40+
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
3841
def test_flux_vae(self):
3942

4043
img_url = os.path.join(THIS_DIR, "images", "test_hyper_sdxl.png")
@@ -67,3 +70,7 @@ def test_flux_vae(self):
6770
image = np.uint8(image * 255)
6871
ssim_compare = ssim(base_image, image, multichannel=True, channel_axis=-1, data_range=255)
6972
assert ssim_compare >= 0.90
73+
74+
75+
if __name__ == "__main__":
76+
absltest.main()

0 commit comments

Comments
 (0)