Skip to content

Commit 43f487c

Browse files
deprecate tests for older stable diffusion models.
1 parent 251dead commit 43f487c

3 files changed

Lines changed: 9 additions & 6 deletions

File tree

src/maxdiffusion/tests/generate_smoke_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def setUp(self):
2222
super().setUp()
2323
Generate.dummy_data = {}
2424

25-
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
25+
@pytest.mark.skip("This test is deprecated and will be removed in a future version.")
2626
def test_sd14_config(self):
2727
img_url = os.path.join(THIS_DIR, "images", "test_gen_sd14.png")
2828
base_image = np.array(Image.open(img_url)).astype(np.uint8)
@@ -42,7 +42,7 @@ def test_sd14_config(self):
4242
assert base_image.shape == test_image.shape
4343
assert ssim_compare >= 0.70
4444

45-
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
45+
@pytest.mark.skip("This test is deprecated and will be removed in a future version.")
4646
def test_sd_2_base_from_gcs(self):
4747
img_url = os.path.join(THIS_DIR, "images", "test_2_base.png")
4848
base_image = np.array(Image.open(img_url)).astype(np.uint8)
@@ -64,7 +64,7 @@ def test_sd_2_base_from_gcs(self):
6464
assert base_image.shape == test_image.shape
6565
assert ssim_compare >= 0.70
6666

67-
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
67+
@pytest.mark.skip("This test is deprecated and will be removed in a future version.")
6868
def test_controlnet(self):
6969
img_url = os.path.join(THIS_DIR, "images", "cnet_test.png")
7070
base_image = np.array(Image.open(img_url)).astype(np.uint8)

src/maxdiffusion/tests/input_pipeline_interface_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import shutil
2121
import subprocess
2222
import unittest
23+
import pytest
2324
from absl.testing import absltest
2425
import numpy as np
2526
import tensorflow as tf
@@ -431,6 +432,7 @@ def test_make_pokemon_iterator_sdxl_cache(self):
431432
config.resolution // vae_scale_factor,
432433
)
433434

435+
@pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace")
434436
def test_make_laion_grain_iterator(self):
435437
try:
436438
subprocess.check_output(
@@ -486,7 +488,8 @@ def test_make_laion_grain_iterator(self):
486488
config.resolution // vae_scale_factor,
487489
8,
488490
)
489-
491+
492+
@pytest.mark.skip("This test is deprecated and will be removed in a future version. Reason: stable diffusion 2 base is no longer in HuggingFace")
490493
def test_make_laion_tfrecord_iterator(self):
491494
pyconfig.initialize(
492495
[

src/maxdiffusion/tests/train_smoke_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_sdxl_config(self):
9696

9797
delete_blobs(os.path.join(output_dir, run_name))
9898

99-
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
99+
@pytest.mark.skip("This test is deprecated and will be removed in a future version.")
100100
def test_dreambooth_orbax(self):
101101
num_class_images = 100
102102
output_dir = "gs://maxdiffusion-github-runner-test-assets"
@@ -149,7 +149,7 @@ def test_dreambooth_orbax(self):
149149
cleanup(class_class_local_dir)
150150
delete_blobs(os.path.join(output_dir, run_name))
151151

152-
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
152+
@pytest.mark.skip("This test is deprecated and will be removed in a future version.")
153153
def test_sd15_orbax(self):
154154
output_dir = "gs://maxdiffusion-github-runner-test-assets"
155155
run_name = "sd15_orbax_smoke_test"

0 commit comments

Comments
 (0)