1313 See the License for the specific language governing permissions and
1414 limitations under the License.
1515 """
16+
17+ from qwix import QtProvider
1618import os
1719import jax
1820import jax .numpy as jnp
@@ -290,7 +292,7 @@ def test_get_qt_provider(self, mock_qt_rule):
290292 config_int8 = Mock (spec = HyperParameters )
291293 config_int8 .use_qwix_quantization = True
292294 config_int8 .quantization = "int8"
293- provider_int8 = WanPipeline .get_qt_provider (config_int8 )
295+ provider_int8 : QtProvider = WanPipeline .get_qt_provider (config_int8 )
294296 self .assertIsNotNone (provider_int8 )
295297 mock_qt_rule .assert_called_once_with (
296298 module_path = '.*' ,
@@ -305,11 +307,7 @@ def test_get_qt_provider(self, mock_qt_rule):
305307 config_fp8 .quantization = "fp8"
306308 provider_fp8 = WanPipeline .get_qt_provider (config_fp8 )
307309 self .assertIsNotNone (provider_fp8 )
308- mock_qt_rule .assert_called_once_with (
309- module_path = '.*' ,
310- weight_qtype = jnp .float8_e4m3fn ,
311- act_qtype = jnp .float8_e4m3fn
312- )
310+ self .assertEqual (provider_fp8 .rules [0 ].kwargs ["weight_qtype" ], jnp .float8_e4m3fn )
313311
314312 # Case 4: Quantization enabled, type 'fp8_full'
315313 mock_qt_rule .reset_mock ()
@@ -319,17 +317,7 @@ def test_get_qt_provider(self, mock_qt_rule):
319317 config_fp8_full .quantization_calibration_method = "absmax"
320318 provider_fp8_full = WanPipeline .get_qt_provider (config_fp8_full )
321319 self .assertIsNotNone (provider_fp8_full )
322- mock_qt_rule .assert_called_once_with (
323- module_path = '.*' , # Apply to all modules
324- weight_qtype = jnp .float8_e4m3fn ,
325- act_qtype = jnp .float8_e4m3fn ,
326- bwd_qtype = jnp .float8_e5m2 ,
327- bwd_use_original_residuals = True ,
328- disable_channelwise_axes = True , # per_tensor calibration
329- weight_calibration_method = config_fp8_full .quantization_calibration_method ,
330- act_calibration_method = config_fp8_full .quantization_calibration_method ,
331- bwd_calibration_method = config_fp8_full .quantization_calibration_method ,
332- )
320+ self .assertEqual (provider_fp8_full .rules [0 ].kwargs ["bwd_qtype" ], jnp .float8_e5m2 )
333321
334322 # Case 5: Invalid quantization type
335323 config_invalid = Mock (spec = HyperParameters )
@@ -338,8 +326,8 @@ def test_get_qt_provider(self, mock_qt_rule):
338326 self .assertIsNone (WanPipeline .get_qt_provider (config_invalid ))
339327
340328 # To test quantize_transformer, we patch its external dependencies
341- @patch (' maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model' )
342- @patch (' maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs' )
329+ @patch (" maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model" )
330+ @patch (" maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs" )
343331 def test_quantize_transformer_enabled (self , mock_get_dummy_inputs , mock_quantize_model ):
344332 """
345333 Tests that quantize_transformer calls qwix when quantization is enabled.
@@ -370,14 +358,14 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
370358 # Check that the model returned is the new quantized model
371359 self .assertIs (result , mock_quantized_model_obj )
372360
373- @patch (' maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model' )
361+ @patch (" maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model" )
374362 def test_quantize_transformer_disabled (self , mock_quantize_model ):
375363 """
376364 Tests that quantize_transformer is skipped when quantization is disabled.
377365 """
378366 # Setup Mocks
379367 mock_config = Mock (spec = HyperParameters )
380- mock_config .use_qwix_quantization = False # Main condition for this test
368+ mock_config .use_qwix_quantization = False # Main condition for this test
381369
382370 mock_model = Mock (spec = WanModel )
383371
0 commit comments