@@ -343,8 +343,8 @@ def create_real_rule_instance(*args, **kwargs):
343343 act_qtype = jnp .float8_e4m3fn ,
344344 bwd_qtype = jnp .float8_e5m2 ,
345345 disable_channelwise_axes = True , # per_tensor calibration
346- weight_calibration_method = config_fp8_full . quantization_calibration_method ,
347- act_calibration_method = config_fp8_full . quantization_calibration_method ,
346+ weight_calibration_method = "fixed,-224,224" ,
347+ act_calibration_method = "fixed,-224,224" ,
348348 bwd_calibration_method = config_fp8_full .quantization_calibration_method ,
349349 op_names = ("dot_general" , "einsum" ),
350350 ),
@@ -354,8 +354,8 @@ def create_real_rule_instance(*args, **kwargs):
354354 act_qtype = jnp .float8_e4m3fn ,
355355 bwd_qtype = jnp .float8_e4m3fn ,
356356 disable_channelwise_axes = True , # per_tensor calibration
357- weight_calibration_method = config_fp8_full . quantization_calibration_method ,
358- act_calibration_method = config_fp8_full . quantization_calibration_method ,
357+ weight_calibration_method = "fixed,-224,224" ,
358+ act_calibration_method = "fixed,-224,224" ,
359359 bwd_calibration_method = config_fp8_full .quantization_calibration_method ,
360360 op_names = ("conv_general_dilated" ),
361361 ),
0 commit comments