Skip to content

Commit 22a87b8

Browse files
committed
fix unit test
1 parent 6aee280 commit 22a87b8

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,8 @@ def create_real_rule_instance(*args, **kwargs):
343343
act_qtype=jnp.float8_e4m3fn,
344344
bwd_qtype=jnp.float8_e5m2,
345345
disable_channelwise_axes=True, # per_tensor calibration
346-
weight_calibration_method=config_fp8_full.quantization_calibration_method,
347-
act_calibration_method=config_fp8_full.quantization_calibration_method,
346+
weight_calibration_method="fixed,-224,224",
347+
act_calibration_method="fixed,-224,224",
348348
bwd_calibration_method=config_fp8_full.quantization_calibration_method,
349349
op_names=("dot_general", "einsum"),
350350
),
@@ -354,8 +354,8 @@ def create_real_rule_instance(*args, **kwargs):
354354
act_qtype=jnp.float8_e4m3fn,
355355
bwd_qtype=jnp.float8_e4m3fn,
356356
disable_channelwise_axes=True, # per_tensor calibration
357-
weight_calibration_method=config_fp8_full.quantization_calibration_method,
358-
act_calibration_method=config_fp8_full.quantization_calibration_method,
357+
weight_calibration_method="fixed,-224,224",
358+
act_calibration_method="fixed,-224,224",
359359
bwd_calibration_method=config_fp8_full.quantization_calibration_method,
360360
op_names=("conv_general_dilated"),
361361
),

0 commit comments

Comments
 (0)