diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 07da52652..b6a73ee5d 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. """ -from qwix import QtProvider import os import jax import jax.numpy as jnp @@ -277,7 +276,8 @@ def test_wan_model(self): ) assert dummy_output.shape == hidden_states_shape - def test_get_qt_provider(self): + @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.QtRule') + def test_get_qt_provider(self, mock_qt_rule): """ Tests the provider logic for all config branches. """ @@ -290,26 +290,46 @@ def test_get_qt_provider(self): config_int8 = Mock(spec=HyperParameters) config_int8.use_qwix_quantization = True config_int8.quantization = "int8" - provider_int8:QtProvider = WanPipeline.get_qt_provider(config_int8) + provider_int8 = WanPipeline.get_qt_provider(config_int8) self.assertIsNotNone(provider_int8) - self.assertEqual(provider_int8._rules[0].weight_qtype, jnp.int8) + mock_qt_rule.assert_called_once_with( + module_path='.*', + weight_qtype=jnp.int8, + act_qtype=jnp.int8 + ) # Case 3: Quantization enabled, type 'fp8' + mock_qt_rule.reset_mock() config_fp8 = Mock(spec=HyperParameters) config_fp8.use_qwix_quantization = True config_fp8.quantization = "fp8" provider_fp8 = WanPipeline.get_qt_provider(config_fp8) self.assertIsNotNone(provider_fp8) - self.assertEqual(provider_fp8.rules[0].kwargs['weight_qtype'], jnp.float8_e4m3fn) + mock_qt_rule.assert_called_once_with( + module_path='.*', + weight_qtype=jnp.float8_e4m3fn, + act_qtype=jnp.float8_e4m3fn + ) # Case 4: Quantization enabled, type 'fp8_full' + mock_qt_rule.reset_mock() config_fp8_full = Mock(spec=HyperParameters) config_fp8_full.use_qwix_quantization = True config_fp8_full.quantization = "fp8_full" config_fp8_full.quantization_calibration_method = "absmax" provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full) self.assertIsNotNone(provider_fp8_full) - self.assertEqual(provider_fp8_full.rules[0].kwargs['bwd_qtype'], jnp.float8_e5m2) + mock_qt_rule.assert_called_once_with( + module_path='.*', # Apply to all modules + weight_qtype=jnp.float8_e4m3fn, + act_qtype=jnp.float8_e4m3fn, + bwd_qtype=jnp.float8_e5m2, + bwd_use_original_residuals=True, + disable_channelwise_axes=True, # per_tensor calibration + weight_calibration_method = config_fp8_full.quantization_calibration_method, + act_calibration_method = config_fp8_full.quantization_calibration_method, + bwd_calibration_method = config_fp8_full.quantization_calibration_method, + ) # Case 5: Invalid quantization type config_invalid = Mock(spec=HyperParameters) @@ -333,6 +353,8 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize mock_model = Mock(spec=WanModel) mock_pipeline = Mock() mock_mesh = Mock() + mock_mesh.__enter__ = Mock(return_value=None) + mock_mesh.__exit__ = Mock(return_value=None) # Mock the return values of dependencies mock_get_dummy_inputs.return_value = (Mock(), Mock(), Mock())