1313 See the License for the specific language governing permissions and
1414 limitations under the License.
1515 """
16-
17- from qwix import QtProvider
1816import os
1917import jax
2018import jax .numpy as jnp
@@ -292,7 +290,7 @@ def test_get_qt_provider(self, mock_qt_rule):
292290 config_int8 = Mock (spec = HyperParameters )
293291 config_int8 .use_qwix_quantization = True
294292 config_int8 .quantization = "int8"
295- provider_int8 : QtProvider = WanPipeline .get_qt_provider (config_int8 )
293+ provider_int8 = WanPipeline .get_qt_provider (config_int8 )
296294 self .assertIsNotNone (provider_int8 )
297295 mock_qt_rule .assert_called_once_with (
298296 module_path = '.*' ,
@@ -307,7 +305,11 @@ def test_get_qt_provider(self, mock_qt_rule):
307305 config_fp8 .quantization = "fp8"
308306 provider_fp8 = WanPipeline .get_qt_provider (config_fp8 )
309307 self .assertIsNotNone (provider_fp8 )
310- self .assertEqual (provider_fp8 .rules [0 ].kwargs ["weight_qtype" ], jnp .float8_e4m3fn )
308+ mock_qt_rule .assert_called_once_with (
309+ module_path = '.*' ,
310+ weight_qtype = jnp .float8_e4m3fn ,
311+ act_qtype = jnp .float8_e4m3fn
312+ )
311313
312314 # Case 4: Quantization enabled, type 'fp8_full'
313315 mock_qt_rule .reset_mock ()
@@ -317,7 +319,17 @@ def test_get_qt_provider(self, mock_qt_rule):
317319 config_fp8_full .quantization_calibration_method = "absmax"
318320 provider_fp8_full = WanPipeline .get_qt_provider (config_fp8_full )
319321 self .assertIsNotNone (provider_fp8_full )
320- self .assertEqual (provider_fp8_full .rules [0 ].kwargs ["bwd_qtype" ], jnp .float8_e5m2 )
322+ mock_qt_rule .assert_called_once_with (
323+ module_path = '.*' , # Apply to all modules
324+ weight_qtype = jnp .float8_e4m3fn ,
325+ act_qtype = jnp .float8_e4m3fn ,
326+ bwd_qtype = jnp .float8_e5m2 ,
327+ bwd_use_original_residuals = True ,
328+ disable_channelwise_axes = True , # per_tensor calibration
329+ weight_calibration_method = config_fp8_full .quantization_calibration_method ,
330+ act_calibration_method = config_fp8_full .quantization_calibration_method ,
331+ bwd_calibration_method = config_fp8_full .quantization_calibration_method ,
332+ )
321333
322334 # Case 5: Invalid quantization type
323335 config_invalid = Mock (spec = HyperParameters )
@@ -326,8 +338,8 @@ def test_get_qt_provider(self, mock_qt_rule):
326338 self .assertIsNone (WanPipeline .get_qt_provider (config_invalid ))
327339
328340 # To test quantize_transformer, we patch its external dependencies
329- @patch (" maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model" )
330- @patch (" maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs" )
341+ @patch (' maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model' )
342+ @patch (' maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs' )
331343 def test_quantize_transformer_enabled (self , mock_get_dummy_inputs , mock_quantize_model ):
332344 """
333345 Tests that quantize_transformer calls qwix when quantization is enabled.
@@ -358,14 +370,14 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
358370 # Check that the model returned is the new quantized model
359371 self .assertIs (result , mock_quantized_model_obj )
360372
361- @patch (" maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model" )
373+ @patch (' maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model' )
362374 def test_quantize_transformer_disabled (self , mock_quantize_model ):
363375 """
364376 Tests that quantize_transformer is skipped when quantization is disabled.
365377 """
366378 # Setup Mocks
367379 mock_config = Mock (spec = HyperParameters )
368- mock_config .use_qwix_quantization = False # Main condition for this test
380+ mock_config .use_qwix_quantization = False # Main condition for this test
369381
370382 mock_model = Mock (spec = WanModel )
371383
0 commit comments