1313 See the License for the specific language governing permissions and
1414 limitations under the License.
1515 """
16- from qwix import QtProvider
1716import os
1817import jax
1918import jax .numpy as jnp
@@ -277,7 +276,8 @@ def test_wan_model(self):
277276 )
278277 assert dummy_output .shape == hidden_states_shape
279278
280- def test_get_qt_provider (self ):
279+ @patch ('maxdiffusion.pipelines.wan.wan_pipeline.qwix.QtRule' )
280+ def test_get_qt_provider (self , mock_qt_rule ):
281281 """
282282 Tests the provider logic for all config branches.
283283 """
@@ -290,26 +290,46 @@ def test_get_qt_provider(self):
290290 config_int8 = Mock (spec = HyperParameters )
291291 config_int8 .use_qwix_quantization = True
292292 config_int8 .quantization = "int8"
293- provider_int8 : QtProvider = WanPipeline .get_qt_provider (config_int8 )
293+ provider_int8 = WanPipeline .get_qt_provider (config_int8 )
294294 self .assertIsNotNone (provider_int8 )
295- self .assertEqual (provider_int8 ._rules [0 ].weight_qtype , jnp .int8 )
295+ mock_qt_rule .assert_called_once_with (
296+ module_path = '.*' ,
297+ weight_qtype = jnp .int8 ,
298+ act_qtype = jnp .int8
299+ )
296300
297301 # Case 3: Quantization enabled, type 'fp8'
302+ mock_qt_rule .reset_mock ()
298303 config_fp8 = Mock (spec = HyperParameters )
299304 config_fp8 .use_qwix_quantization = True
300305 config_fp8 .quantization = "fp8"
301306 provider_fp8 = WanPipeline .get_qt_provider (config_fp8 )
302307 self .assertIsNotNone (provider_fp8 )
303- self .assertEqual (provider_fp8 .rules [0 ].kwargs ['weight_qtype' ], jnp .float8_e4m3fn )
308+ mock_qt_rule .assert_called_once_with (
309+ module_path = '.*' ,
310+ weight_qtype = jnp .float8_e4m3fn ,
311+ act_qtype = jnp .float8_e4m3fn
312+ )
304313
305314 # Case 4: Quantization enabled, type 'fp8_full'
315+ mock_qt_rule .reset_mock ()
306316 config_fp8_full = Mock (spec = HyperParameters )
307317 config_fp8_full .use_qwix_quantization = True
308318 config_fp8_full .quantization = "fp8_full"
309319 config_fp8_full .quantization_calibration_method = "absmax"
310320 provider_fp8_full = WanPipeline .get_qt_provider (config_fp8_full )
311321 self .assertIsNotNone (provider_fp8_full )
312- self .assertEqual (provider_fp8_full .rules [0 ].kwargs ['bwd_qtype' ], jnp .float8_e5m2 )
322+ mock_qt_rule .assert_called_once_with (
323+ module_path = '.*' , # Apply to all modules
324+ weight_qtype = jnp .float8_e4m3fn ,
325+ act_qtype = jnp .float8_e4m3fn ,
326+ bwd_qtype = jnp .float8_e5m2 ,
327+ bwd_use_original_residuals = True ,
328+ disable_channelwise_axes = True , # per_tensor calibration
329+ weight_calibration_method = config_fp8_full .quantization_calibration_method ,
330+ act_calibration_method = config_fp8_full .quantization_calibration_method ,
331+ bwd_calibration_method = config_fp8_full .quantization_calibration_method ,
332+ )
313333
314334 # Case 5: Invalid quantization type
315335 config_invalid = Mock (spec = HyperParameters )
@@ -333,6 +353,8 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
333353 mock_model = Mock (spec = WanModel )
334354 mock_pipeline = Mock ()
335355 mock_mesh = Mock ()
356+ mock_mesh .__enter__ = Mock (return_value = None )
357+ mock_mesh .__exit__ = Mock (return_value = None )
336358
337359 # Mock the return values of dependencies
338360 mock_get_dummy_inputs .return_value = (Mock (), Mock (), Mock ())
0 commit comments