Skip to content

Commit 1054ee1

Browse files
committed
fix test
1 parent e12aec8 commit 1054ee1

1 file changed

Lines changed: 24 additions & 5 deletions

File tree

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import jax.numpy as jnp
2020
import pytest
2121
import unittest
22-
from unittest.mock import Mock, patch
22+
from unittest.mock import Mock, patch, call
2323
from absl.testing import absltest
2424
from flax import nnx
2525
from jax.sharding import Mesh
@@ -291,28 +291,43 @@ def test_get_qt_provider(self, mock_qt_rule):
291291
config_int8 = Mock(spec=HyperParameters)
292292
config_int8.use_qwix_quantization = True
293293
config_int8.quantization = "int8"
294+
config_int8.qwix_module_path = ".*"
294295
provider_int8 = WanPipeline.get_qt_provider(config_int8)
295296
self.assertIsNotNone(provider_int8)
296-
mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.int8, act_qtype=jnp.int8)
297+
mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.int8, act_qtype=jnp.int8, op_names=("dot_general","einsum", "conv_general_dilated"))
297298

298299
# Case 3: Quantization enabled, type 'fp8'
299300
mock_qt_rule.reset_mock()
300301
config_fp8 = Mock(spec=HyperParameters)
301302
config_fp8.use_qwix_quantization = True
302303
config_fp8.quantization = "fp8"
304+
config_fp8.qwix_module_path = ".*"
303305
provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
304306
self.assertIsNotNone(provider_fp8)
305-
mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn)
307+
mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, op_names=("dot_general","einsum", "conv_general_dilated"))
306308

307309
# Case 4: Quantization enabled, type 'fp8_full'
308310
mock_qt_rule.reset_mock()
309311
config_fp8_full = Mock(spec=HyperParameters)
310312
config_fp8_full.use_qwix_quantization = True
311313
config_fp8_full.quantization = "fp8_full"
312314
config_fp8_full.quantization_calibration_method = "absmax"
315+
config_fp8_full.qwix_module_path = ".*"
313316
provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
314317
self.assertIsNotNone(provider_fp8_full)
315-
mock_qt_rule.assert_called_once_with(
318+
expected_calls = [
319+
call(module_path=".*", # Apply to all modules
320+
weight_qtype=jnp.float8_e4m3fn,
321+
act_qtype=jnp.float8_e4m3fn,
322+
bwd_qtype=jnp.float8_e5m2,
323+
bwd_use_original_residuals=True,
324+
disable_channelwise_axes=True, # per_tensor calibration
325+
weight_calibration_method=config_fp8_full.quantization_calibration_method,
326+
act_calibration_method=config_fp8_full.quantization_calibration_method,
327+
bwd_calibration_method=config_fp8_full.quantization_calibration_method,
328+
op_names=("dot_general","einsum"),
329+
),
330+
call(
316331
module_path=".*", # Apply to all modules
317332
weight_qtype=jnp.float8_e4m3fn,
318333
act_qtype=jnp.float8_e4m3fn,
@@ -322,7 +337,10 @@ def test_get_qt_provider(self, mock_qt_rule):
322337
weight_calibration_method=config_fp8_full.quantization_calibration_method,
323338
act_calibration_method=config_fp8_full.quantization_calibration_method,
324339
bwd_calibration_method=config_fp8_full.quantization_calibration_method,
325-
)
340+
op_names=("conv_general_dilated",),  # NOTE(review): without the trailing comma this was a plain string, not a 1-tuple — confirm the production call site passes a tuple
341+
)
342+
]
343+
mock_qt_rule.assert_has_calls(expected_calls, any_order=True)
326344

327345
# Case 5: Invalid quantization type
328346
config_invalid = Mock(spec=HyperParameters)
@@ -341,6 +359,7 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
341359
mock_config = Mock(spec=HyperParameters)
342360
mock_config.use_qwix_quantization = True
343361
mock_config.quantization = "fp8_full"
362+
mock_config.qwix_module_path = ".*"
344363
mock_config.per_device_batch_size = 1
345364

346365
mock_model = Mock(spec=WanModel)

0 commit comments

Comments
 (0)