File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1717import os
1818import jax
1919import jax .numpy as jnp
20+ import pytest
2021import unittest
2122from absl .testing import absltest
2223from flax import nnx
3435from ..models .normalization_flax import FP32LayerNorm
3536from ..models .attention_flax import FlaxWanAttention
3637
38+ IN_GITHUB_ACTIONS = os .getenv ("GITHUB_ACTIONS" ) == "true"
39+
3740THIS_DIR = os .path .dirname (os .path .abspath (__file__ ))
3841
3942
@@ -81,6 +84,7 @@ def test_fp32_layer_norm(self):
8184 dummy_output = layer (dummy_hidden_states )
8285 assert dummy_output .shape == dummy_hidden_states .shape
8386
87+ @pytest .mark .skipif (IN_GITHUB_ACTIONS , reason = "Don't run smoke tests on Github Actions" )
8488 def test_wan_time_text_embedding (self ):
8589 key = jax .random .key (0 )
8690 rngs = nnx .Rngs (key )
@@ -231,6 +235,7 @@ def test_wan_attention(self):
231235 except NotImplementedError :
232236 pass
233237
238+ @pytest .mark .skipif (IN_GITHUB_ACTIONS , reason = "Don't run smoke tests on Github Actions" )
234239 def test_wan_model (self ):
235240 pyconfig .initialize (
236241 [
You can’t perform that action at this time.
0 commit comments