Skip to content

Commit 02c46eb

Browse files
committed
fix qwix unit test
1 parent 1c81058 commit 02c46eb

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
from ..models.attention_flax import FlaxWanAttention
3838
from maxdiffusion.pyconfig import HyperParameters
3939
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
40+
import qwix
41+
42+
RealQtRule = qwix.QtRule
4043

4144

4245
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
@@ -282,6 +285,10 @@ def test_get_qt_provider(self, mock_qt_rule):
282285
"""
283286
Tests the provider logic for all config branches.
284287
"""
288+
def create_real_rule_instance(*args, **kwargs):
289+
return RealQtRule(*args, **kwargs)
290+
mock_qt_rule.side_effect = create_real_rule_instance
291+
285292
# Case 1: Quantization disabled
286293
config_disabled = Mock(spec=HyperParameters)
287294
config_disabled.use_qwix_quantization = False
@@ -361,6 +368,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
361368
mock_config.quantization = "fp8_full"
362369
mock_config.qwix_module_path = ".*"
363370
mock_config.per_device_batch_size = 1
371+
mock_config.quantization_calibration_method = "absmax"
364372

365373
mock_model = Mock(spec=WanModel)
366374
mock_pipeline = Mock()

0 commit comments

Comments
 (0)