src/maxdiffusion/tests/wan_transformer_test.py (30 changes: 21 additions & 9 deletions)
@@ -13,8 +13,6 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-
-from qwix import QtProvider
 import os
 import jax
 import jax.numpy as jnp
@@ -292,7 +290,7 @@ def test_get_qt_provider(self, mock_qt_rule):
     config_int8 = Mock(spec=HyperParameters)
     config_int8.use_qwix_quantization = True
     config_int8.quantization = "int8"
-    provider_int8: QtProvider = WanPipeline.get_qt_provider(config_int8)
+    provider_int8 = WanPipeline.get_qt_provider(config_int8)
     self.assertIsNotNone(provider_int8)
     mock_qt_rule.assert_called_once_with(
         module_path='.*',
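
Note for reviewers: dropping the from qwix import QtProvider line together with the provider_int8: QtProvider annotation works because these tests never depend on the concrete provider type; the configs are plain Mock(spec=HyperParameters) objects. A minimal sketch of what spec= buys here, using a hypothetical stand-in for the real config class:

  from unittest.mock import Mock

  class HyperParameters:  # hypothetical stand-in for the real config class
    use_qwix_quantization = False
    quantization = ""

  cfg = Mock(spec=HyperParameters)
  cfg.use_qwix_quantization = True  # fine: the attribute exists on the spec
  cfg.quantization = "int8"         # fine
  cfg.made_up_field                 # AttributeError: reads outside the spec fail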
@@ -307,7 +305,11 @@
     config_fp8.quantization = "fp8"
     provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
     self.assertIsNotNone(provider_fp8)
-    self.assertEqual(provider_fp8.rules[0].kwargs["weight_qtype"], jnp.float8_e4m3fn)
+    mock_qt_rule.assert_called_once_with(
+        module_path='.*',
+        weight_qtype=jnp.float8_e4m3fn,
+        act_qtype=jnp.float8_e4m3fn
+    )
 
     # Case 4: Quantization enabled, type 'fp8_full'
     mock_qt_rule.reset_mock()
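
Replacing the old assertEqual on provider_fp8.rules[0].kwargs with mock_qt_rule.assert_called_once_with decouples the test from qwix's internal rule representation and verifies every constructor argument in one assertion. A sketch of the pattern, assuming QtRule is patched at the wan_pipeline module path (the actual patch target for QtRule is not visible in this excerpt):

  from unittest.mock import patch

  with patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.QtRule') as mock_qt_rule:
    provider = WanPipeline.get_qt_provider(config_fp8)
    # One assertion covers all constructor kwargs, with no reliance on
    # provider.rules[0].kwargs or other qwix internals.
    mock_qt_rule.assert_called_once_with(
        module_path='.*',
        weight_qtype=jnp.float8_e4m3fn,
        act_qtype=jnp.float8_e4m3fn,
    )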
@@ -317,7 +319,17 @@
     config_fp8_full.quantization_calibration_method = "absmax"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
-    self.assertEqual(provider_fp8_full.rules[0].kwargs["bwd_qtype"], jnp.float8_e5m2)
+    mock_qt_rule.assert_called_once_with(
+        module_path='.*',  # Apply to all modules
+        weight_qtype=jnp.float8_e4m3fn,
+        act_qtype=jnp.float8_e4m3fn,
+        bwd_qtype=jnp.float8_e5m2,
+        bwd_use_original_residuals=True,
+        disable_channelwise_axes=True,  # per_tensor calibration
+        weight_calibration_method=config_fp8_full.quantization_calibration_method,
+        act_calibration_method=config_fp8_full.quantization_calibration_method,
+        bwd_calibration_method=config_fp8_full.quantization_calibration_method,
+    )
 
     # Case 5: Invalid quantization type
     config_invalid = Mock(spec=HyperParameters)
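
For reviewers skimming the expected kwargs: bwd_qtype=jnp.float8_e5m2 quantizes the backward pass with the wider-exponent e5m2 format while weights and activations use e4m3, and disable_channelwise_axes=True forces per-tensor calibration. The assertions imply a get_qt_provider branch roughly like the sketch below; this is reconstructed from the test expectations, not the actual wan_pipeline.py source, and assumes qwix.QtProvider wraps a list of QtRule objects, as the old rules[0] access suggested:

  import jax.numpy as jnp
  import qwix

  def get_qt_provider(config):
    # Reconstructed sketch of the 'fp8_full' branch only.
    if not config.use_qwix_quantization:
      return None
    if config.quantization == "fp8_full":
      return qwix.QtProvider([
          qwix.QtRule(
              module_path='.*',  # apply to all modules
              weight_qtype=jnp.float8_e4m3fn,
              act_qtype=jnp.float8_e4m3fn,
              bwd_qtype=jnp.float8_e5m2,
              bwd_use_original_residuals=True,
              disable_channelwise_axes=True,  # per-tensor calibration
              weight_calibration_method=config.quantization_calibration_method,
              act_calibration_method=config.quantization_calibration_method,
              bwd_calibration_method=config.quantization_calibration_method,
          )
      ])
    return None  # unsupported types yield no provider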
@@ -326,8 +338,8 @@ def test_get_qt_provider(self, mock_qt_rule):
     self.assertIsNone(WanPipeline.get_qt_provider(config_invalid))
 
   # To test quantize_transformer, we patch its external dependencies
-  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
-  @patch("maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs")
+  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
+  @patch('maxdiffusion.pipelines.wan.wan_pipeline.get_dummy_wan_inputs')
   def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize_model):
     """
     Tests that quantize_transformer calls qwix when quantization is enabled.
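
A unittest.mock detail these stacked decorators rely on: patches apply bottom-up, so the mock for the bottom @patch is injected first. That is why mock_get_dummy_inputs precedes mock_quantize_model in the signature above. Illustrated with hypothetical targets:

  from unittest.mock import patch

  @patch('pkg.module.outer_dep')  # applied last, injected as the 2nd argument
  @patch('pkg.module.inner_dep')  # applied first, injected as the 1st argument
  def test_example(self, mock_inner_dep, mock_outer_dep):
    ...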
@@ -358,14 +370,14 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     # Check that the model returned is the new quantized model
     self.assertIs(result, mock_quantized_model_obj)
 
-  @patch("maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model")
+  @patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.quantize_model')
   def test_quantize_transformer_disabled(self, mock_quantize_model):
     """
     Tests that quantize_transformer is skipped when quantization is disabled.
     """
     # Setup Mocks
     mock_config = Mock(spec=HyperParameters)
-    mock_config.use_qwix_quantization = False # Main condition for this test
+    mock_config.use_qwix_quantization = False  # Main condition for this test
 
     mock_model = Mock(spec=WanModel)
 
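
The body of test_quantize_transformer_disabled is truncated in this view, but given the setup it presumably asserts that qwix is never invoked and the model passes through unchanged. A minimal sketch of that assertion style, assuming quantize_transformer takes the config and model (its real signature is cut off in this diff):

  result = WanPipeline.quantize_transformer(mock_config, mock_model)

  # With use_qwix_quantization=False the call should be a no-op:
  mock_quantize_model.assert_not_called()
  self.assertIs(result, mock_model)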