diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py
index eddf8bf94..b6a73ee5d 100644
--- a/src/maxdiffusion/tests/wan_transformer_test.py
+++ b/src/maxdiffusion/tests/wan_transformer_test.py
@@ -13,8 +13,6 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-
-from qwix import QtProvider
 import os
 import jax
 import jax.numpy as jnp
@@ -292,7 +290,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_int8 = Mock(spec=HyperParameters)
     config_int8.use_qwix_quantization = True
     config_int8.quantization = "int8"
-    provider_int8: QtProvider = WanPipeline.get_qt_provider(config_int8)
+    provider_int8 = WanPipeline.get_qt_provider(config_int8)
     self.assertIsNotNone(provider_int8)
     mock_qt_rule.assert_called_once_with(
         module_path='.*',
@@ -307,7 +305,11 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_fp8.quantization = "fp8"
     provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
     self.assertIsNotNone(provider_fp8)
-    self.assertEqual(provider_fp8.rules[0].kwargs["weight_qtype"], jnp.float8_e4m3fn)
+    mock_qt_rule.assert_called_once_with(
+        module_path='.*',
+        weight_qtype=jnp.float8_e4m3fn,
+        act_qtype=jnp.float8_e4m3fn
+    )
 
     # Case 4: Quantization enabled, type 'fp8_full'
     mock_qt_rule.reset_mock()
@@ -317,7 +319,17 @@ config_fp8_full.quantization = "fp8_full"
     config_fp8_full.quantization_calibration_method = "absmax"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
-    self.assertEqual(provider_fp8_full.rules[0].kwargs["bwd_qtype"], jnp.float8_e5m2)
+    mock_qt_rule.assert_called_once_with(
+        module_path='.*',  # Apply to all modules
+        weight_qtype=jnp.float8_e4m3fn,
+        act_qtype=jnp.float8_e4m3fn,
+        bwd_qtype=jnp.float8_e5m2,
+        bwd_use_original_residuals=True,
+        disable_channelwise_axes=True,  # per_tensor calibration
+        weight_calibration_method = config_fp8_full.quantization_calibration_method,
+        act_calibration_method = config_fp8_full.quantization_calibration_method,
+        bwd_calibration_method = config_fp8_full.quantization_calibration_method,
+    )
 
     # Case 5: Invalid quantization type
     config_invalid = Mock(spec=HyperParameters)
@@ -326,8 +338,8 @@
     self.assertIsNone(WanPipeline.get_qt_provider(config_invalid))
 
   # To test quantize_transformer, we patch its external dependencies
-  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
-  @patch("maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs")
+  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
+  @patch('maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs')
   def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model):
     """
     Tests that quantize_transformer calls qwix when quantization is enabled.
     """
@@ -358,14 +370,14 @@
     # Check that the model returned is the new quantized model
     self.assertIs(result, mock_quantized_model_obj)
 
-  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
+  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
   def test_quantize_transformer_disabled(self, mock_quantize_model):
     """
     Tests that quantize_transformer is skipped when quantization is disabled.
     """
 
     # Setup Mocks
     mock_config = Mock(spec=HyperParameters)
-    mock_config.use_qwix_quantization = False # Main condition for this test
+    mock_config.use_qwix_quantization = False  # Main condition for this test
     mock_model = Mock(spec=WanModel)