File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -81,6 +81,7 @@ class QuantizationType(str, Enum):
8181 """Supported quantization schemes."""
8282
8383 NONE = ""
84+ INT4 = "int4"
8485 INT8 = "int8"
8586 INTMP = "intmp"
8687 FP8 = "fp8"
Original file line number Diff line number Diff line change @@ -655,6 +655,15 @@ def get_fp8_full_qwix_rule(config: Config):
655655
656656def get_quantization_rule (config : Config ):
657657 match config .quantization :
658+ case "int4" :
659+ return qwix .QtRule (
660+ module_path = "decoder/.*layers.*" ,
661+ weight_qtype = jnp .int4 ,
662+ act_qtype = jnp .int4 ,
663+ bwd_qtype = jnp .int4 ,
664+ bwd_weight_grad_tile_size = 1 / config .quantization_local_shard_count ,
665+ op_names = ("dot_general" ,),
666+ )
658667 case "int8" :
659668 return qwix .QtRule (
660669 module_path = "decoder/.*layers.*" ,
@@ -702,6 +711,8 @@ def get_qt_provider(config):
702711 match config .quantization :
703712 case "int8" :
704713 return qwix .QtProvider ([get_quantization_rule (config )])
714+ case "int4" :
715+ return qwix .QtProvider ([get_quantization_rule (config )])
705716 case "fp8" :
706717 return qwix .QtProvider ([get_quantization_rule (config )])
707718 case "fp8_full" :
You can’t perform that action at this time.
0 commit comments