@@ -278,7 +278,8 @@ def test_wan_model(self):
278278 )
279279 assert dummy_output .shape == hidden_states_shape
280280
281- def test_get_qt_provider (self ):
281+ @patch ('maxdiffusion.pipelines.wan.wan_pipeline.qwix.QtRule' )
282+ def test_get_qt_provider (self , mock_qt_rule ):
282283 """
283284 Tests the provider logic for all config branches.
284285 """
@@ -293,9 +294,14 @@ def test_get_qt_provider(self):
293294 config_int8 .quantization = "int8"
294295 provider_int8 : QtProvider = WanPipeline .get_qt_provider (config_int8 )
295296 self .assertIsNotNone (provider_int8 )
296- self .assertEqual (provider_int8 ._rules [0 ].weight_qtype , jnp .int8 )
297+ mock_qt_rule .assert_called_once_with (
298+ module_path = '.*' ,
299+ weight_qtype = jnp .int8 ,
300+ act_qtype = jnp .int8
301+ )
297302
298303 # Case 3: Quantization enabled, type 'fp8'
304+ mock_qt_rule .reset_mock ()
299305 config_fp8 = Mock (spec = HyperParameters )
300306 config_fp8 .use_qwix_quantization = True
301307 config_fp8 .quantization = "fp8"
@@ -304,6 +310,7 @@ def test_get_qt_provider(self):
304310 self .assertEqual (provider_fp8 .rules [0 ].kwargs ["weight_qtype" ], jnp .float8_e4m3fn )
305311
306312 # Case 4: Quantization enabled, type 'fp8_full'
313+ mock_qt_rule .reset_mock ()
307314 config_fp8_full = Mock (spec = HyperParameters )
308315 config_fp8_full .use_qwix_quantization = True
309316 config_fp8_full .quantization = "fp8_full"
@@ -334,6 +341,8 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
334341 mock_model = Mock (spec = WanModel )
335342 mock_pipeline = Mock ()
336343 mock_mesh = Mock ()
344+ mock_mesh .__enter__ = Mock (return_value = None )
345+ mock_mesh .__exit__ = Mock (return_value = None )
337346
338347 # Mock the return values of dependencies
339348 mock_get_dummy_inputs .return_value = (Mock (), Mock (), Mock ())
0 commit comments