Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ def get_fp8_config(cls, config: HyperParameters):
weight_qtype=jnp.float8_e4m3fn,
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e5m2,
bwd_use_original_residuals=True,
disable_channelwise_axes=True, # per_tensor calibration
weight_calibration_method=config.quantization_calibration_method,
act_calibration_method=config.quantization_calibration_method,
Expand All @@ -309,7 +308,6 @@ def get_fp8_config(cls, config: HyperParameters):
weight_qtype=jnp.float8_e4m3fn, # conv_general_dilated requires the same dtypes
act_qtype=jnp.float8_e4m3fn,
bwd_qtype=jnp.float8_e4m3fn,
bwd_use_original_residuals=True,
disable_channelwise_axes=True, # per_tensor calibration
weight_calibration_method=config.quantization_calibration_method,
act_calibration_method=config.quantization_calibration_method,
Expand Down
Loading