Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions src/maxdiffusion/tests/wan_transformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from qwix import QtProvider
import os
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -277,7 +276,8 @@ def test_wan_model(self):
)
assert dummy_output.shape == hidden_states_shape

def test_get_qt_provider(self):
@patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.QtRule')
def test_get_qt_provider(self, mock_qt_rule):
"""
Tests the provider logic for all config branches.
"""
Expand All @@ -290,26 +290,46 @@ def test_get_qt_provider(self):
config_int8 = Mock(spec=HyperParameters)
config_int8.use_qwix_quantization = True
config_int8.quantization = "int8"
provider_int8:QtProvider = WanPipeline.get_qt_provider(config_int8)
provider_int8 = WanPipeline.get_qt_provider(config_int8)
self.assertIsNotNone(provider_int8)
self.assertEqual(provider_int8._rules[0].weight_qtype, jnp.int8)
mock_qt_rule.assert_called_once_with(
module_path='.*',
weight_qtype=jnp.int8,
act_qtype=jnp.int8
)

# Case 3: Quantization enabled, type 'fp8'
mock_qt_rule.reset_mock()
config_fp8 = Mock(spec=HyperParameters)
config_fp8.use_qwix_quantization = True
config_fp8.quantization = "fp8"
provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
self.assertIsNotNone(provider_fp8)
self.assertEqual(provider_fp8.rules[0].kwargs['weight_qtype'], jnp.float8_e4m3fn)
mock_qt_rule.assert_called_once_with(
module_path='.*',
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn
)

# Case 4: Quantization enabled, type 'fp8_full'
mock_qt_rule.reset_mock()
config_fp8_full = Mock(spec=HyperParameters)
config_fp8_full.use_qwix_quantization = True
config_fp8_full.quantization = "fp8_full"
config_fp8_full.quantization_calibration_method = "absmax"
provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
self.assertIsNotNone(provider_fp8_full)
self.assertEqual(provider_fp8_full.rules[0].kwargs['bwd_qtype'], jnp.float8_e5m2)
mock_qt_rule.assert_called_once_with(
module_path='.*', # Apply to all modules
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e5m2,
bwd_use_original_residuals=True,
disable_channelwise_axes=True, # per_tensor calibration
weight_calibration_method = config_fp8_full.quantization_calibration_method,
act_calibration_method = config_fp8_full.quantization_calibration_method,
bwd_calibration_method = config_fp8_full.quantization_calibration_method,
)

# Case 5: Invalid quantization type
config_invalid = Mock(spec=HyperParameters)
Expand All @@ -333,6 +353,8 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
mock_model = Mock(spec=WanModel)
mock_pipeline = Mock()
mock_mesh = Mock()
mock_mesh.__enter__ = Mock(return_value=None)
mock_mesh.__exit__ = Mock(return_value=None)

# Mock the return values of dependencies
mock_get_dummy_inputs.return_value = (Mock(), Mock(), Mock())
Expand Down
Loading