From 8c4465d4ddaf3c4a1f9235ec953c4c9c8fe8040b Mon Sep 17 00:00:00 2001
From: Sanbao Su
Date: Wed, 13 Aug 2025 11:28:01 -0700
Subject: [PATCH 1/3] fix wan unit test bugs

---
 .../tests/wan_transformer_test.py | 36 +++++++++++++++----
 1 file changed, 29 insertions(+), 7 deletions(-)

diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py
index 07da52652..b449d4517 100644
--- a/src/maxdiffusion/tests/wan_transformer_test.py
+++ b/src/maxdiffusion/tests/wan_transformer_test.py
@@ -13,7 +13,6 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-from qwix import QtProvider
 import os
 import jax
 import jax.numpy as jnp
@@ -277,7 +276,8 @@ def test_wan_model(self):
     )
     assert dummy_output.shape == hidden_states_shape
 
-  def test_get_qt_provider(self):
+  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.QtRule')
+  def test_get_qt_provider(self, mock_qt_rule):
     """
     Tests the provider logic for all config branches.
     """
@@ -290,26 +290,46 @@ def test_get_qt_provider(self):
     config_int8 = Mock(spec=HyperParameters)
     config_int8.use_qwix_quantization = True
     config_int8.quantization = "int8"
-    provider_int8:QtProvider = WanPipeline.get_qt_provider(config_int8)
+    provider_int8 = WanPipeline.get_qt_provider(config_int8)
     self.assertIsNotNone(provider_int8)
-    self.assertEqual(provider_int8._rules[0].weight_qtype, jnp.int8)
+    mock_qt_rule.assert_called_once_with(
+        module_path='.*',
+        weight_qtype=jnp.int8,
+        act_qtype=jnp.int8
+    )
 
     # Case 3: Quantization enabled, type 'fp8'
+    mock_qt_rule.reset_mock()
     config_fp8 = Mock(spec=HyperParameters)
     config_fp8.use_qwix_quantization = True
     config_fp8.quantization = "fp8"
     provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
     self.assertIsNotNone(provider_fp8)
-    self.assertEqual(provider_fp8.rules[0].kwargs['weight_qtype'], jnp.float8_e4m3fn)
-
+    mock_qt_rule.assert_called_once_with(
+        module_path='.*',
+        weight_qtype=jnp.float8_e4m3fn,
+        act_qtype=jnp.float8_e4m3fn
+    )
+
     # Case 4: Quantization enabled, type 'fp8_full'
+    mock_qt_rule.reset_mock()
     config_fp8_full = Mock(spec=HyperParameters)
     config_fp8_full.use_qwix_quantization = True
     config_fp8_full.quantization = "fp8_full"
     config_fp8_full.quantization_calibration_method = "absmax"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
-    self.assertEqual(provider_fp8_full.rules[0].kwargs['bwd_qtype'], jnp.float8_e5m2)
+    mock_qt_rule.assert_called_once_with(
+        module_path='.*',  # Apply to all modules
+        weight_qtype=jnp.float8_e4m3fn,
+        act_qtype=jnp.float8_e4m3fn,
+        bwd_qtype=jnp.float8_e5m2,
+        bwd_use_original_residuals=True,
+        disable_channelwise_axes=True,  # per_tensor calibration
+        weight_calibration_method = config_fp8_full.quantization_calibration_method,
+        act_calibration_method = config_fp8_full.quantization_calibration_method,
+        bwd_calibration_method = config_fp8_full.quantization_calibration_method,
+    )
 
     # Case 5: Invalid quantization type
     config_invalid = Mock(spec=HyperParameters)
@@ -333,6 +353,8 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     mock_model = Mock(spec=WanModel)
     mock_pipeline = Mock()
     mock_mesh = Mock()
+    mock_mesh.__enter__ = Mock(return_value=None)
+    mock_mesh.__exit__ = Mock(return_value=None)
 
     # Mock the return values of dependencies
     mock_get_dummy_inputs.return_value = (Mock(), Mock(), Mock())

From 55b309c0acc02a907e99e89226f3b362ad285183 Mon Sep 17 00:00:00 2001
From: susanbao
Date: Wed, 13 Aug 2025 18:33:57 +0000
Subject: [PATCH 2/3] line problems

---
 src/maxdiffusion/tests/wan_transformer_test.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py
index b449d4517..b6a73ee5d 100644
--- a/src/maxdiffusion/tests/wan_transformer_test.py
+++ b/src/maxdiffusion/tests/wan_transformer_test.py
@@ -310,7 +310,7 @@ def test_get_qt_provider(self, mock_qt_rule):
         weight_qtype=jnp.float8_e4m3fn,
         act_qtype=jnp.float8_e4m3fn
     )
-
+
     # Case 4: Quantization enabled, type 'fp8_full'
     mock_qt_rule.reset_mock()
     config_fp8_full = Mock(spec=HyperParameters)
@@ -354,7 +354,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     mock_pipeline = Mock()
     mock_mesh = Mock()
     mock_mesh.__enter__ = Mock(return_value=None)
-    mock_mesh.__exit__ = Mock(return_value=None)
+    mock_mesh.__exit__ = Mock(return_value=None)
 
     # Mock the return values of dependencies
     mock_get_dummy_inputs.return_value = (Mock(), Mock(), Mock())

From cd6eeff7140cf5ab3090db404b008dbc007de2dc Mon Sep 17 00:00:00 2001
From: Sanbao Su
Date: Mon, 25 Aug 2025 10:43:12 -0700
Subject: [PATCH 3/3] bug fix under low per_device_batch_size

---
 src/maxdiffusion/pipelines/wan/wan_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index 1659d3bb5..78e3322d1 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -286,7 +286,7 @@ def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline
       return model
 
     max_logging.log("Quantizing transformer with Qwix.")
-    batch_size = int(config.per_device_batch_size * jax.local_device_count())
+    batch_size = jnp.ceil(config.per_device_batch_size * jax.local_device_count()).astype(jnp.int32)
     latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size)
     model_inputs = (latents, timesteps, prompt_embeds)
     with mesh:
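
Note on PATCH 1/3: the test now asserts on the constructor call of the patched qwix.QtRule instead of reaching into provider internals such as provider_int8._rules[0] or provider_fp8.rules[0].kwargs. The following standalone sketch illustrates that mocking pattern outside the repo; Rule and build_provider are hypothetical stand-ins, not maxdiffusion or qwix APIs.

# Minimal sketch of the "assert on constructor kwargs" pattern (illustrative only).
from unittest.mock import patch

class Rule:  # stand-in for a rule class such as qwix.QtRule
  def __init__(self, **kwargs):
    self.kwargs = kwargs

def build_provider(weight_qtype):  # stand-in for WanPipeline.get_qt_provider
  return [Rule(module_path='.*', weight_qtype=weight_qtype)]

with patch(f'{__name__}.Rule') as mock_rule:
  build_provider('int8')
  # Asserting on the call avoids depending on private attributes of the
  # returned provider object, which is what made the original test brittle.
  mock_rule.assert_called_once_with(module_path='.*', weight_qtype='int8')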
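
Note on PATCH 3/3: with a fractional per_device_batch_size, the old int() truncation can evaluate to zero, so the dummy inputs used for Qwix tracing would contain no samples; jnp.ceil rounds up instead. A minimal sketch of the arithmetic, assuming a hypothetical per_device_batch_size of 0.5 on a single-device host (the value is not taken from the patch):

# Sketch of the batch-size arithmetic changed in PATCH 3/3 (illustrative only).
import jax
import jax.numpy as jnp

per_device_batch_size = 0.5  # hypothetical "low" value
local_devices = jax.local_device_count()

# Previous behaviour: truncation, which collapses to 0 when the product is < 1.
old_batch_size = int(per_device_batch_size * local_devices)

# Patched behaviour: round up so at least one sample is traced.
new_batch_size = jnp.ceil(per_device_batch_size * local_devices).astype(jnp.int32)

print(old_batch_size, int(new_batch_size))  # e.g. 0 vs 1 when local_devices == 1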