Skip to content

Commit 2d4d611

Browse files
committed
fixe unit test
1 parent d93e0ea commit 2d4d611

1 file changed

Lines changed: 9 additions & 7 deletions

File tree

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,9 @@ def create_real_rule_instance(*args, **kwargs):
332332
config_fp8_full = Mock(spec=HyperParameters)
333333
config_fp8_full.use_qwix_quantization = True
334334
config_fp8_full.quantization = "fp8_full"
335-
config_fp8_full.quantization_calibration_method = "absmax"
335+
config_fp8_full.weight_quantization_calibration_method = "absmax"
336+
config_fp8_full.act_quantization_calibration_method = "absmax"
337+
config_fp8_full.bwd_quantization_calibration_method = "absmax"
336338
config_fp8_full.qwix_module_path = ".*"
337339
provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
338340
self.assertIsNotNone(provider_fp8_full)
@@ -343,9 +345,9 @@ def create_real_rule_instance(*args, **kwargs):
343345
act_qtype=jnp.float8_e4m3fn,
344346
bwd_qtype=jnp.float8_e5m2,
345347
disable_channelwise_axes=True, # per_tensor calibration
346-
weight_calibration_method=config_fp8_full.quantization_calibration_method,
347-
act_calibration_method=config_fp8_full.quantization_calibration_method,
348-
bwd_calibration_method=config_fp8_full.quantization_calibration_method,
348+
weight_calibration_method=config_fp8_full.weight_quantization_calibration_method,
349+
act_calibration_method=config_fp8_full.act_quantization_calibration_method,
350+
bwd_calibration_method=config_fp8_full.bwd_quantization_calibration_method,
349351
op_names=("dot_general", "einsum"),
350352
),
351353
call(
@@ -354,9 +356,9 @@ def create_real_rule_instance(*args, **kwargs):
354356
act_qtype=jnp.float8_e4m3fn,
355357
bwd_qtype=jnp.float8_e4m3fn,
356358
disable_channelwise_axes=True, # per_tensor calibration
357-
weight_calibration_method=config_fp8_full.quantization_calibration_method,
358-
act_calibration_method=config_fp8_full.quantization_calibration_method,
359-
bwd_calibration_method=config_fp8_full.quantization_calibration_method,
359+
weight_calibration_method=config_fp8_full.weight_quantization_calibration_method,
360+
act_calibration_method=config_fp8_full.act_quantization_calibration_method,
361+
bwd_calibration_method=config_fp8_full.bwd_quantization_calibration_method,
360362
op_names=("conv_general_dilated"),
361363
),
362364
]

0 commit comments

Comments
 (0)