Skip to content

Commit d93e0ea

Browse files
committed
change quantization calibration method
1 parent 7cbb714 commit d93e0ea

2 files changed

Lines changed: 10 additions & 8 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,10 @@ quantization: ''
319319
quantization_local_shard_count: -1
320320
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
321321
use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix.
322-
# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
323-
quantization_calibration_method: "absmax"
322+
# Quantization calibration method used for weights, activations and bwd. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
323+
weight_quantization_calibration_method: "absmax"
324+
act_quantization_calibration_method: "absmax"
325+
bwd_quantization_calibration_method: "absmax"
324326
qwix_module_path: ".*"
325327

326328
# Eval model on per eval_every steps. -1 means don't eval.

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,9 @@ def get_fp8_config(cls, config: HyperParameters):
302302
act_qtype=jnp.float8_e4m3fn,
303303
bwd_qtype=jnp.float8_e5m2,
304304
disable_channelwise_axes=True, # per_tensor calibration
305-
weight_calibration_method=config.quantization_calibration_method,
306-
act_calibration_method=config.quantization_calibration_method,
307-
bwd_calibration_method=config.quantization_calibration_method,
305+
weight_calibration_method=config.weight_quantization_calibration_method,
306+
act_calibration_method=config.act_quantization_calibration_method,
307+
bwd_calibration_method=config.bwd_quantization_calibration_method,
308308
op_names=("dot_general", "einsum"),
309309
),
310310
qwix.QtRule(
@@ -313,9 +313,9 @@ def get_fp8_config(cls, config: HyperParameters):
313313
act_qtype=jnp.float8_e4m3fn,
314314
bwd_qtype=jnp.float8_e4m3fn,
315315
disable_channelwise_axes=True, # per_tensor calibration
316-
weight_calibration_method=config.quantization_calibration_method,
317-
act_calibration_method=config.quantization_calibration_method,
318-
bwd_calibration_method=config.quantization_calibration_method,
316+
weight_calibration_method=config.weight_quantization_calibration_method,
317+
act_calibration_method=config.act_quantization_calibration_method,
318+
bwd_calibration_method=config.bwd_quantization_calibration_method,
319319
op_names=("conv_general_dilated"),
320320
),
321321
]

0 commit comments

Comments (0)