Skip to content

Commit 0815b96

Browse files
committed
fix error: conv_general_qt requires the same bwd_qtype as weight_qtype
1 parent fd53cde commit 0815b96

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def get_fp8_config(cls, quantization_calibration_method: str):
265265
module_path=".*", # Apply to all modules
266266
weight_qtype=jnp.float8_e4m3fn,
267267
act_qtype=jnp.float8_e4m3fn,
268-
bwd_qtype=jnp.float8_e5m2,
268+
bwd_qtype=jnp.float8_e4m3fn,
269269
bwd_use_original_residuals=True,
270270
disable_channelwise_axes=True, # per_tensor calibration
271271
weight_calibration_method=quantization_calibration_method,

0 commit comments

Comments
 (0)