diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 36846a148..115c90545 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -297,7 +297,6 @@ def get_fp8_config(cls, config: HyperParameters): weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e5m2, - bwd_use_original_residuals=True, disable_channelwise_axes=True, # per_tensor calibration weight_calibration_method=config.quantization_calibration_method, act_calibration_method=config.quantization_calibration_method, @@ -309,7 +308,6 @@ def get_fp8_config(cls, config: HyperParameters): weight_qtype=jnp.float8_e4m3fn, # conv_general_dilated requires the same dtypes act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e4m3fn, - bwd_use_original_residuals=True, disable_channelwise_axes=True, # per_tensor calibration weight_calibration_method=config.quantization_calibration_method, act_calibration_method=config.quantization_calibration_method, diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 2a3da9094..34f0ef642 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -342,7 +342,6 @@ def create_real_rule_instance(*args, **kwargs): weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e5m2, - bwd_use_original_residuals=True, disable_channelwise_axes=True, # per_tensor calibration weight_calibration_method=config_fp8_full.quantization_calibration_method, act_calibration_method=config_fp8_full.quantization_calibration_method, @@ -354,7 +353,6 @@ def create_real_rule_instance(*args, **kwargs): weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, bwd_qtype=jnp.float8_e4m3fn, - bwd_use_original_residuals=True, disable_channelwise_axes=True, # per_tensor calibration weight_calibration_method=config_fp8_full.quantization_calibration_method, act_calibration_method=config_fp8_full.quantization_calibration_method,