@@ -640,6 +640,19 @@ def dot_general(self, *args, **kwargs):
640640 return nn .NANOOFp8DotGeneralOp (name = op_id )(* args , ** kwargs )
641641
642642
def get_fp8_full_qwix_rule(config: Config):
  """Build the qwix rule used for the "fp8_full" quantization mode.

  The rule targets modules whose path matches ``decoder/.*layers.*`` and
  applies to the ``dot_general``, ``gmm`` and ``ragged_dot`` ops. Weights and
  activations use ``float8_e4m3fn``; the ``bwd_qtype`` is ``float8_e5m2``.

  Args:
    config: Configuration object supplying the three calibration-method
      settings (weight, activation and backward) that are forwarded to the
      rule unchanged.

  Returns:
    The ``qwix.QtRule`` instance constructed from the settings above.
  """
  make_rule = qwix.QtRule

  # Weights and activations share one dtype; bwd_qtype uses a different one.
  forward_dtype = jnp.float8_e4m3fn
  backward_dtype = jnp.float8_e5m2

  # Calibration methods are taken verbatim from the config.
  weight_calibration = config.weight_quantization_calibration_method
  act_calibration = config.act_quantization_calibration_method
  bwd_calibration = config.bwd_quantization_calibration_method

  rule = make_rule(
      module_path="decoder/.*layers.*",
      weight_qtype=forward_dtype,
      act_qtype=forward_dtype,
      bwd_qtype=backward_dtype,
      weight_calibration_method=weight_calibration,
      act_calibration_method=act_calibration,
      bwd_calibration_method=bwd_calibration,
      op_names=("dot_general", "gmm", "ragged_dot"),
  )
  return rule
654+
655+
643656def get_quantization_rule (config : Config ):
644657 match config .quantization :
645658 case "int8" :
@@ -661,16 +674,7 @@ def get_quantization_rule(config: Config):
661674 op_names = ("dot_general" ,),
662675 )
663676 case "fp8_full" :
664- return qwix .QtRule (
665- module_path = "decoder/.*layers.*" ,
666- weight_qtype = jnp .float8_e4m3fn ,
667- act_qtype = jnp .float8_e4m3fn ,
668- bwd_qtype = jnp .float8_e5m2 ,
669- weight_calibration_method = config .weight_quantization_calibration_method ,
670- act_calibration_method = config .act_quantization_calibration_method ,
671- bwd_calibration_method = config .bwd_quantization_calibration_method ,
672- op_names = ("dot_general" , "gmm" , "ragged_dot" ),
673- )
677+ return get_fp8_full_qwix_rule (config )
674678 case "fp8_gpu" :
675679 return qwix .QtRule (
676680 module_path = "decoder/.*layers.*" ,
@@ -808,7 +812,7 @@ def generate_quantizer_set(self, postfix: str = ""):
808812 postfix = postfix ,
809813 variable_collection = OVERWRITE_WITH_GRADIENT ,
810814 quantization_checkpoint_name = "quantization" ,
811- fp8_recipe = fp8_recipe
815+ fp8_recipe = fp8_recipe ,
812816 )
813817
814818 @nn .compact
0 commit comments