1919import jax .numpy as jnp
2020import pytest
2121import unittest
22- from unittest .mock import Mock , patch
22+ from unittest .mock import Mock , patch , call
2323from absl .testing import absltest
2424from flax import nnx
2525from jax .sharding import Mesh
@@ -291,28 +291,43 @@ def test_get_qt_provider(self, mock_qt_rule):
291291 config_int8 = Mock (spec = HyperParameters )
292292 config_int8 .use_qwix_quantization = True
293293 config_int8 .quantization = "int8"
294+ config_int8 .qwix_module_path = ".*"
294295 provider_int8 = WanPipeline .get_qt_provider (config_int8 )
295296 self .assertIsNotNone (provider_int8 )
296- mock_qt_rule .assert_called_once_with (module_path = ".*" , weight_qtype = jnp .int8 , act_qtype = jnp .int8 )
297+ mock_qt_rule .assert_called_once_with (module_path = ".*" , weight_qtype = jnp .int8 , act_qtype = jnp .int8 , op_names = ( "dot_general" , "einsum" , "conv_general_dilated" ) )
297298
298299 # Case 3: Quantization enabled, type 'fp8'
299300 mock_qt_rule .reset_mock ()
300301 config_fp8 = Mock (spec = HyperParameters )
301302 config_fp8 .use_qwix_quantization = True
302303 config_fp8 .quantization = "fp8"
304+ config_int8 .qwix_module_path = ".*"
303305 provider_fp8 = WanPipeline .get_qt_provider (config_fp8 )
304306 self .assertIsNotNone (provider_fp8 )
305- mock_qt_rule .assert_called_once_with (module_path = ".*" , weight_qtype = jnp .float8_e4m3fn , act_qtype = jnp .float8_e4m3fn )
307+ mock_qt_rule .assert_called_once_with (module_path = ".*" , weight_qtype = jnp .float8_e4m3fn , act_qtype = jnp .float8_e4m3fn , op_names = ( "dot_general" , "einsum" , "conv_general_dilated" ) )
306308
307309 # Case 4: Quantization enabled, type 'fp8_full'
308310 mock_qt_rule .reset_mock ()
309311 config_fp8_full = Mock (spec = HyperParameters )
310312 config_fp8_full .use_qwix_quantization = True
311313 config_fp8_full .quantization = "fp8_full"
312314 config_fp8_full .quantization_calibration_method = "absmax"
315+ config_int8 .qwix_module_path = ".*"
313316 provider_fp8_full = WanPipeline .get_qt_provider (config_fp8_full )
314317 self .assertIsNotNone (provider_fp8_full )
315- mock_qt_rule .assert_called_once_with (
318+ expected_calls = [
319+ call (module_path = ".*" , # Apply to all modules
320+ weight_qtype = jnp .float8_e4m3fn ,
321+ act_qtype = jnp .float8_e4m3fn ,
322+ bwd_qtype = jnp .float8_e5m2 ,
323+ bwd_use_original_residuals = True ,
324+ disable_channelwise_axes = True , # per_tensor calibration
325+ weight_calibration_method = config_fp8_full .quantization_calibration_method ,
326+ act_calibration_method = config_fp8_full .quantization_calibration_method ,
327+ bwd_calibration_method = config_fp8_full .quantization_calibration_method ,
328+ op_names = ("dot_general" ,"einsum" ),
329+ ),
330+ call (
316331 module_path = ".*" , # Apply to all modules
317332 weight_qtype = jnp .float8_e4m3fn ,
318333 act_qtype = jnp .float8_e4m3fn ,
@@ -322,7 +337,10 @@ def test_get_qt_provider(self, mock_qt_rule):
322337 weight_calibration_method = config_fp8_full .quantization_calibration_method ,
323338 act_calibration_method = config_fp8_full .quantization_calibration_method ,
324339 bwd_calibration_method = config_fp8_full .quantization_calibration_method ,
325- )
340+ op_names = ("conv_general_dilated" ),
341+ )
342+ ]
343+ mock_qt_rule .assert_has_calls (expected_calls , any_order = True )
326344
327345 # Case 5: Invalid quantization type
328346 config_invalid = Mock (spec = HyperParameters )
@@ -341,6 +359,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
341359 mock_config = Mock (spec = HyperParameters )
342360 mock_config .use_qwix_quantization = True
343361 mock_config .quantization = "fp8_full"
362+ mock_config .qwix_module_path = ".*"
344363 mock_config .per_device_batch_size = 1
345364
346365 mock_model = Mock (spec = WanModel )
0 commit comments