1313 See the License for the specific language governing permissions and
1414 limitations under the License.
1515 """
16+
1617from qwix import QtProvider
1718import os
1819import jax
@@ -290,7 +291,7 @@ def test_get_qt_provider(self):
290291 config_int8 = Mock (spec = HyperParameters )
291292 config_int8 .use_qwix_quantization = True
292293 config_int8 .quantization = "int8"
293- provider_int8 :QtProvider = WanPipeline .get_qt_provider (config_int8 )
294+ provider_int8 : QtProvider = WanPipeline .get_qt_provider (config_int8 )
294295 self .assertIsNotNone (provider_int8 )
295296 self .assertEqual (provider_int8 ._rules [0 ].weight_qtype , jnp .int8 )
296297
@@ -300,7 +301,7 @@ def test_get_qt_provider(self):
300301 config_fp8 .quantization = "fp8"
301302 provider_fp8 = WanPipeline .get_qt_provider (config_fp8 )
302303 self .assertIsNotNone (provider_fp8 )
303- self .assertEqual (provider_fp8 .rules [0 ].kwargs [' weight_qtype' ], jnp .float8_e4m3fn )
304+ self .assertEqual (provider_fp8 .rules [0 ].kwargs [" weight_qtype" ], jnp .float8_e4m3fn )
304305
305306 # Case 4: Quantization enabled, type 'fp8_full'
306307 config_fp8_full = Mock (spec = HyperParameters )
@@ -309,7 +310,7 @@ def test_get_qt_provider(self):
309310 config_fp8_full .quantization_calibration_method = "absmax"
310311 provider_fp8_full = WanPipeline .get_qt_provider (config_fp8_full )
311312 self .assertIsNotNone (provider_fp8_full )
312- self .assertEqual (provider_fp8_full .rules [0 ].kwargs [' bwd_qtype' ], jnp .float8_e5m2 )
313+ self .assertEqual (provider_fp8_full .rules [0 ].kwargs [" bwd_qtype" ], jnp .float8_e5m2 )
313314
314315 # Case 5: Invalid quantization type
315316 config_invalid = Mock (spec = HyperParameters )
@@ -318,8 +319,8 @@ def test_get_qt_provider(self):
318319 self .assertIsNone (WanPipeline .get_qt_provider (config_invalid ))
319320
320321 # To test quantize_transformer, we patch its external dependencies
321- @patch (' maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model' )
322- @patch (' maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs' )
322+ @patch (" maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model" )
323+ @patch (" maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs" )
323324 def test_quantize_transformer_enabled (self , mock_get_dummy_inputs , mock_quantize_model ):
324325 """
325326 Tests that quantize_transformer calls qwix when quantization is enabled.
@@ -348,14 +349,14 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
348349 # Check that the model returned is the new quantized model
349350 self .assertIs (result , mock_quantized_model_obj )
350351
351- @patch (' maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model' )
352+ @patch (" maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model" )
352353 def test_quantize_transformer_disabled (self , mock_quantize_model ):
353354 """
354355 Tests that quantize_transformer is skipped when quantization is disabled.
355356 """
356357 # Setup Mocks
357358 mock_config = Mock (spec = HyperParameters )
358- mock_config .use_qwix_quantization = False # Main condition for this test
359+ mock_config .use_qwix_quantization = False # Main condition for this test
359360
360361 mock_model = Mock (spec = WanModel )
361362
0 commit comments