From 0815b9650243e2b19385e807ba70869c70a3a07b Mon Sep 17 00:00:00 2001 From: susanbao Date: Fri, 12 Sep 2025 23:22:42 +0000 Subject: [PATCH 1/2] fix error: conv_general_qt requires the same bwd_qtype as weight_qtype --- src/maxdiffusion/pipelines/wan/wan_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index f2c3701e3..e93c9884d 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -265,7 +265,7 @@ def get_fp8_config(cls, quantization_calibration_method: str): module_path=".*", # Apply to all modules weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e5m2, + bwd_qtype=jnp.float8_e4m3fn, bwd_use_original_residuals=True, disable_channelwise_axes=True, # per_tensor calibration weight_calibration_method=quantization_calibration_method, From ffb0825017dd2e724f180797e02acf3a274031c6 Mon Sep 17 00:00:00 2001 From: susanbao Date: Sat, 13 Sep 2025 00:06:03 +0000 Subject: [PATCH 2/2] fix test error --- src/maxdiffusion/tests/wan_transformer_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 84efa064e..26ea0f028 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -316,7 +316,7 @@ def test_get_qt_provider(self, mock_qt_rule): module_path=".*", # Apply to all modules weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e5m2, + bwd_qtype=jnp.float8_e4m3fn, bwd_use_original_residuals=True, disable_channelwise_axes=True, # per_tensor calibration weight_calibration_method=config_fp8_full.quantization_calibration_method,