diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py
index ae638e059..3d1327c3b 100644
--- a/src/maxdiffusion/tests/wan_transformer_test.py
+++ b/src/maxdiffusion/tests/wan_transformer_test.py
@@ -37,6 +37,11 @@ from ..models.attention_flax import FlaxWanAttention

 from maxdiffusion.pyconfig import HyperParameters
 from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
+import qwix
+import flax
+
+flax.config.update('flax_always_shard_variable', False)
+RealQtRule = qwix.QtRule

 IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"

@@ -282,6 +287,10 @@ def test_get_qt_provider(self, mock_qt_rule):
     """
     Tests the provider logic for all config branches.
     """
+    def create_real_rule_instance(*args, **kwargs):
+      return RealQtRule(*args, **kwargs)
+    mock_qt_rule.side_effect = create_real_rule_instance
+
     # Case 1: Quantization disabled
     config_disabled = Mock(spec=HyperParameters)
     config_disabled.use_qwix_quantization = False
@@ -301,7 +310,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_fp8 = Mock(spec=HyperParameters)
     config_fp8.use_qwix_quantization = True
     config_fp8.quantization = "fp8"
-    config_int8.qwix_module_path = ".*"
+    config_fp8.qwix_module_path = ".*"
     provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
     self.assertIsNotNone(provider_fp8)
     mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, op_names=("dot_general","einsum", "conv_general_dilated"))
@@ -312,7 +321,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_fp8_full.use_qwix_quantization = True
     config_fp8_full.quantization = "fp8_full"
     config_fp8_full.quantization_calibration_method = "absmax"
-    config_int8.qwix_module_path = ".*"
+    config_fp8_full.qwix_module_path = ".*"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
     expected_calls = [
@@ -361,6 +370,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     mock_config.quantization = "fp8_full"
     mock_config.qwix_module_path = ".*"
     mock_config.per_device_batch_size = 1
+    mock_config.quantization_calibration_method = "absmax"

     mock_model = Mock(spec=WanModel)
     mock_pipeline = Mock()
diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py
index 7b131e7fb..66d8dce9d 100644
--- a/src/maxdiffusion/tests/wan_vae_test.py
+++ b/src/maxdiffusion/tests/wan_vae_test.py
@@ -46,12 +46,13 @@ from ..models.wan.wan_utils import load_wan_vae
 from ..utils import load_video
 from ..video_processor import VideoProcessor
+import flax

 THIS_DIR = os.path.dirname(os.path.abspath(__file__))

 CACHE_T = 2


-
+flax.config.update('flax_always_shard_variable', False)
 class TorchWanRMS_norm(nn.Module):
   r"""
   A custom RMS normalization layer.
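
Note on the test pattern above (not part of the patch): the `create_real_rule_instance` side effect lets the test keep asserting on `qwix.QtRule` call arguments while the code under test still receives genuine rule objects rather than bare mocks. A minimal, self-contained sketch of the same pattern follows; `make_provider` is a hypothetical stand-in for `WanPipeline.get_qt_provider`, and the `QtRule` arguments mirror the calls asserted in the diff.

```python
from unittest.mock import patch

import jax.numpy as jnp
import qwix

# Keep a handle on the real class before patching, as the patch above does.
RealQtRule = qwix.QtRule


def make_provider(module_path):
  # Hypothetical stand-in for WanPipeline.get_qt_provider: builds a single
  # fp8 rule with the same arguments the test asserts on.
  return [
      qwix.QtRule(
          module_path=module_path,
          weight_qtype=jnp.float8_e4m3fn,
          act_qtype=jnp.float8_e4m3fn,
          op_names=("dot_general", "einsum", "conv_general_dilated"),
      )
  ]


with patch("qwix.QtRule") as mock_qt_rule:
  # Forward every call to the real class: the mock records call arguments,
  # but the provider still gets a usable QtRule instance.
  mock_qt_rule.side_effect = lambda *args, **kwargs: RealQtRule(*args, **kwargs)
  rules = make_provider(".*")
  mock_qt_rule.assert_called_once_with(
      module_path=".*",
      weight_qtype=jnp.float8_e4m3fn,
      act_qtype=jnp.float8_e4m3fn,
      op_names=("dot_general", "einsum", "conv_general_dilated"),
  )
  assert isinstance(rules[0], RealQtRule)
```

Without the side effect, the patched `QtRule` would return a bare `Mock`, and any downstream code that inspects rule attributes would pass vacuously; forwarding to the real class keeps the assertions meaningful.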