@@ -302,8 +302,8 @@ def get_fp8_config(cls, config: HyperParameters):
302302 act_qtype = jnp .float8_e4m3fn ,
303303 bwd_qtype = jnp .float8_e5m2 ,
304304 disable_channelwise_axes = True , # per_tensor calibration
305- weight_calibration_method = config . quantization_calibration_method ,
306- act_calibration_method = config . quantization_calibration_method ,
305+ weight_calibration_method = "fixed,-224,224" ,
306+ act_calibration_method = "fixed,-224,224" ,
307307 bwd_calibration_method = config .quantization_calibration_method ,
308308 op_names = ("dot_general" , "einsum" ),
309309 ),
@@ -313,8 +313,8 @@ def get_fp8_config(cls, config: HyperParameters):
313313 act_qtype = jnp .float8_e4m3fn ,
314314 bwd_qtype = jnp .float8_e4m3fn ,
315315 disable_channelwise_axes = True , # per_tensor calibration
316- weight_calibration_method = config . quantization_calibration_method ,
317- act_calibration_method = config . quantization_calibration_method ,
316+ weight_calibration_method = "fixed,-224,224" ,
317+ act_calibration_method = "fixed,-224,224" ,
318318 bwd_calibration_method = config .quantization_calibration_method ,
319319 op_names = ("conv_general_dilated" ),
320320 ),
0 commit comments