Skip to content

Commit cc2c288

Browse files
update tests.
1 parent 50a029d commit cc2c288

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import jax
1919
import jax.numpy as jnp
20+
import pytest
2021
import unittest
2122
from absl.testing import absltest
2223
from flax import nnx
@@ -34,6 +35,8 @@
3435
from ..models.normalization_flax import FP32LayerNorm
3536
from ..models.attention_flax import FlaxWanAttention
3637

38+
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
39+
3740
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
3841

3942

@@ -81,6 +84,7 @@ def test_fp32_layer_norm(self):
8184
dummy_output = layer(dummy_hidden_states)
8285
assert dummy_output.shape == dummy_hidden_states.shape
8386

87+
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
8488
def test_wan_time_text_embedding(self):
8589
key = jax.random.key(0)
8690
rngs = nnx.Rngs(key)
@@ -231,6 +235,7 @@ def test_wan_attention(self):
231235
except NotImplementedError:
232236
pass
233237

238+
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
234239
def test_wan_model(self):
235240
pyconfig.initialize(
236241
[

0 commit comments

Comments
 (0)