Skip to content

Commit ee53ee3

Browse files
authored
fix wan unit test bugs (#231)
* fix wan unit test bugs * line problems
1 parent f279995 commit ee53ee3

1 file changed

Lines changed: 28 additions & 6 deletions

File tree

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16-
from qwix import QtProvider
1716
import os
1817
import jax
1918
import jax.numpy as jnp
@@ -277,7 +276,8 @@ def test_wan_model(self):
277276
)
278277
assert dummy_output.shape == hidden_states_shape
279278

280-
def test_get_qt_provider(self):
279+
@patch('maxdiffusion.pipelines.wan.wan_pipeline.qwix.QtRule')
280+
def test_get_qt_provider(self, mock_qt_rule):
281281
"""
282282
Tests the provider logic for all config branches.
283283
"""
@@ -290,26 +290,46 @@ def test_get_qt_provider(self):
290290
config_int8 = Mock(spec=HyperParameters)
291291
config_int8.use_qwix_quantization = True
292292
config_int8.quantization = "int8"
293-
provider_int8:QtProvider = WanPipeline.get_qt_provider(config_int8)
293+
provider_int8 = WanPipeline.get_qt_provider(config_int8)
294294
self.assertIsNotNone(provider_int8)
295-
self.assertEqual(provider_int8._rules[0].weight_qtype, jnp.int8)
295+
mock_qt_rule.assert_called_once_with(
296+
module_path='.*',
297+
weight_qtype=jnp.int8,
298+
act_qtype=jnp.int8
299+
)
296300

297301
# Case 3: Quantization enabled, type 'fp8'
302+
mock_qt_rule.reset_mock()
298303
config_fp8 = Mock(spec=HyperParameters)
299304
config_fp8.use_qwix_quantization = True
300305
config_fp8.quantization = "fp8"
301306
provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
302307
self.assertIsNotNone(provider_fp8)
303-
self.assertEqual(provider_fp8.rules[0].kwargs['weight_qtype'], jnp.float8_e4m3fn)
308+
mock_qt_rule.assert_called_once_with(
309+
module_path='.*',
310+
weight_qtype=jnp.float8_e4m3fn,
311+
act_qtype=jnp.float8_e4m3fn
312+
)
304313

305314
# Case 4: Quantization enabled, type 'fp8_full'
315+
mock_qt_rule.reset_mock()
306316
config_fp8_full = Mock(spec=HyperParameters)
307317
config_fp8_full.use_qwix_quantization = True
308318
config_fp8_full.quantization = "fp8_full"
309319
config_fp8_full.quantization_calibration_method = "absmax"
310320
provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
311321
self.assertIsNotNone(provider_fp8_full)
312-
self.assertEqual(provider_fp8_full.rules[0].kwargs['bwd_qtype'], jnp.float8_e5m2)
322+
mock_qt_rule.assert_called_once_with(
323+
module_path='.*', # Apply to all modules
324+
weight_qtype=jnp.float8_e4m3fn,
325+
act_qtype=jnp.float8_e4m3fn,
326+
bwd_qtype=jnp.float8_e5m2,
327+
bwd_use_original_residuals=True,
328+
disable_channelwise_axes=True, # per_tensor calibration
329+
weight_calibration_method = config_fp8_full.quantization_calibration_method,
330+
act_calibration_method = config_fp8_full.quantization_calibration_method,
331+
bwd_calibration_method = config_fp8_full.quantization_calibration_method,
332+
)
313333

314334
# Case 5: Invalid quantization type
315335
config_invalid = Mock(spec=HyperParameters)
@@ -333,6 +353,8 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
333353
mock_model = Mock(spec=WanModel)
334354
mock_pipeline = Mock()
335355
mock_mesh = Mock()
356+
mock_mesh.__enter__ = Mock(return_value=None)
357+
mock_mesh.__exit__ = Mock(return_value=None)
336358

337359
# Mock the return values of dependencies
338360
mock_get_dummy_inputs.return_value = (Mock(), Mock(), Mock())

0 commit comments

Comments
 (0)