@@ -246,10 +246,10 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
246246 def get_basic_config (cls , dtype , config : HyperParameters ):
247247 rules = [
248248 qwix .QtRule (
249- module_path = config .qwix_module_path , # Apply to all modules
249+ module_path = config .qwix_module_path ,
250250 weight_qtype = dtype ,
251251 act_qtype = dtype ,
252- op_names = ("dot_general" ,),
252+ op_names = ("dot_general" ,"einsum" , "conv_general_dilated" ),
253253 )
254254 ]
255255 return rules
@@ -263,7 +263,7 @@ def get_fp8_config(cls, config: HyperParameters):
263263 """
264264 rules = [
265265 qwix .QtRule (
266- module_path = config .qwix_module_path , # Apply to all modules
266+ module_path = config .qwix_module_path ,
267267 weight_qtype = jnp .float8_e4m3fn ,
268268 act_qtype = jnp .float8_e4m3fn ,
269269 bwd_qtype = jnp .float8_e5m2 ,
@@ -272,7 +272,19 @@ def get_fp8_config(cls, config: HyperParameters):
272272 weight_calibration_method = config .quantization_calibration_method ,
273273 act_calibration_method = config .quantization_calibration_method ,
274274 bwd_calibration_method = config .quantization_calibration_method ,
275- op_names = ("dot_general" ,),
275+ op_names = ("dot_general" ,"einsum" ),
276+ ),
277+ qwix .QtRule (
278+ module_path = config .qwix_module_path ,
279+ weight_qtype = jnp .float8_e4m3fn , # conv_general_dilated requires the same dtypes
280+ act_qtype = jnp .float8_e4m3fn ,
281+ bwd_qtype = jnp .float8_e4m3fn ,
282+ bwd_use_original_residuals = True ,
283+ disable_channelwise_axes = True , # per_tensor calibration
284+ weight_calibration_method = config .quantization_calibration_method ,
285+ act_calibration_method = config .quantization_calibration_method ,
286+ bwd_calibration_method = config .quantization_calibration_method ,
287+ op_names = ("conv_general_dilated" ),
276288 )
277289 ]
278290 return rules
0 commit comments