|
19 | 19 | import jax.numpy as jnp |
20 | 20 | import pytest |
21 | 21 | import unittest |
| 22 | +from unittest.mock import Mock, patch, MagicMock |
22 | 23 | from absl.testing import absltest |
23 | 24 | from flax import nnx |
24 | 25 | from jax.sharding import Mesh |
|
34 | 35 | from ..models.embeddings_flax import NNXTimestepEmbedding, NNXPixArtAlphaTextProjection |
35 | 36 | from ..models.normalization_flax import FP32LayerNorm |
36 | 37 | from ..models.attention_flax import FlaxWanAttention |
| 38 | +from maxdiffusion.pyconfig import HyperParameters |
| 39 | +from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline |
| 40 | + |
37 | 41 |
|
38 | 42 | IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" |
39 | 43 |
|
@@ -272,6 +276,96 @@ def test_wan_model(self): |
272 | 276 | hidden_states=dummy_hidden_states, timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states |
273 | 277 | ) |
274 | 278 | assert dummy_output.shape == hidden_states_shape |
| 279 | + |
def test_get_qt_provider(self):
    """Exercise WanPipeline.get_qt_provider for every quantization config branch.

    Each scenario builds a ``Mock(spec=HyperParameters)`` carrying only the
    attributes listed in the table, then checks either that no provider is
    returned, or that the first rule of the returned provider carries the
    expected dtype under the given keyword.

    Scenarios run under ``self.subTest`` so that a failure in one branch is
    reported with its label and does not hide the results of the others.
    """
    # (label, attributes set on the config, expected (rule kwarg, dtype) or None)
    cases = (
        ("disabled", {"use_qwix_quantization": False}, None),
        (
            "int8",
            {"use_qwix_quantization": True, "quantization": "int8"},
            ("weight_qtype", jnp.int8),
        ),
        (
            "fp8",
            {"use_qwix_quantization": True, "quantization": "fp8"},
            ("weight_qtype", jnp.float8_e4m3fn),
        ),
        (
            "fp8_full",
            {
                "use_qwix_quantization": True,
                "quantization": "fp8_full",
                "quantization_calibration_method": "absmax",
            },
            ("bwd_qtype", jnp.float8_e5m2),
        ),
        (
            "invalid_type",
            {"use_qwix_quantization": True, "quantization": "invalid_type"},
            None,
        ),
    )

    for label, attrs, expected in cases:
        with self.subTest(case=label):
            config = Mock(spec=HyperParameters)
            # Plain setattr keeps the mock identical to one configured by
            # direct attribute assignment.
            for attr_name, attr_value in attrs.items():
                setattr(config, attr_name, attr_value)

            provider = WanPipeline.get_qt_provider(config)

            if expected is None:
                # Disabled or unrecognised quantization yields no provider.
                self.assertIsNone(provider)
                continue

            kwarg_name, expected_dtype = expected
            self.assertIsNotNone(provider)
            self.assertEqual(provider.rules[0].kwargs[kwarg_name], expected_dtype)
| 319 | + |
# quantize_transformer is tested with its external dependencies patched out.
@patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
@patch('maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs')
def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model):
    """Check that quantize_transformer delegates to qwix when quantization is on.

    With ``use_qwix_quantization`` enabled, the dummy-input helper and
    ``qwix.quantize_model`` must each be called exactly once, and the object
    returned by ``quantize_model`` must be what ``quantize_transformer`` returns.

    Patch decorators apply bottom-up, so the first mock argument is
    ``get_dummy_wan_inputs`` and the second is ``qwix.quantize_model``.
    """
    # Config with quantization switched on.
    mock_config = Mock(spec=HyperParameters)
    mock_config.use_qwix_quantization = True
    mock_config.quantization = "fp8_full"
    # The fp8_full scenario in test_get_qt_provider also supplies a calibration
    # method. A spec'd Mock raises AttributeError for unset attributes that are
    # not on the spec, so provide it here as well.
    # NOTE(review): presumably read on the fp8_full path - confirm against
    # WanPipeline.get_qt_provider.
    mock_config.quantization_calibration_method = "absmax"
    mock_config.per_device_batch_size = 1

    mock_model = Mock(spec=WanModel)
    mock_pipeline = Mock()
    mock_mesh = Mock()

    # Stub the patched dependencies' return values.
    mock_get_dummy_inputs.return_value = (Mock(), Mock(), Mock())
    mock_quantized_model_obj = Mock(spec=WanModel)
    mock_quantize_model.return_value = mock_quantized_model_obj

    # Call the method under test.
    result = WanPipeline.quantize_transformer(mock_config, mock_model, mock_pipeline, mock_mesh)

    # Both dependencies are used exactly once...
    mock_get_dummy_inputs.assert_called_once()
    mock_quantize_model.assert_called_once()
    # ...and the quantized model, not the input model, is returned.
    self.assertIs(result, mock_quantized_model_obj)
| 350 | + |
@patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
def test_quantize_transformer_disabled(self, mock_quantize_model):
    """Check that quantize_transformer is a pass-through when quantization is off.

    With ``use_qwix_quantization`` set to False, ``qwix.quantize_model`` must
    never be called and the very same model instance must be returned.
    """
    # Config whose only configured attribute is the disabled flag.
    disabled_config = Mock(spec=HyperParameters)
    disabled_config.use_qwix_quantization = False

    original_model = Mock(spec=WanModel)
    pipeline_stub = Mock()
    mesh_stub = Mock()

    returned = WanPipeline.quantize_transformer(
        disabled_config, original_model, pipeline_stub, mesh_stub
    )

    # qwix must not be touched...
    mock_quantize_model.assert_not_called()
    # ...and the caller gets back exactly the object passed in.
    self.assertIs(returned, original_model)
275 | 369 |
|
276 | 370 |
|
277 | 371 | if __name__ == "__main__": |
|
0 commit comments