Skip to content

Commit f11f550

Browse files
Add support for int4.
PiperOrigin-RevId: 880998732
1 parent 76d9f94 commit f11f550

2 files changed

Lines changed: 12 additions & 0 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,7 @@ class QuantizationType(str, Enum):
8181
"""Supported quantization schemes."""
8282

8383
NONE = ""
84+
INT4 = "int4"
8485
INT8 = "int8"
8586
INTMP = "intmp"
8687
FP8 = "fp8"

src/maxtext/layers/quantizations.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -655,6 +655,15 @@ def get_fp8_full_qwix_rule(config: Config):
655655

656656
def get_quantization_rule(config: Config):
657657
match config.quantization:
658+
case "int4":
659+
return qwix.QtRule(
660+
module_path="decoder/.*layers.*",
661+
weight_qtype=jnp.int4,
662+
act_qtype=jnp.int4,
663+
bwd_qtype=jnp.int4,
664+
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
665+
op_names=("dot_general",),
666+
)
658667
case "int8":
659668
return qwix.QtRule(
660669
module_path="decoder/.*layers.*",
@@ -702,6 +711,8 @@ def get_qt_provider(config):
702711
match config.quantization:
703712
case "int8":
704713
return qwix.QtProvider([get_quantization_rule(config)])
714+
case "int4":
715+
return qwix.QtProvider([get_quantization_rule(config)])
705716
case "fp8":
706717
return qwix.QtProvider([get_quantization_rule(config)])
707718
case "fp8_full":

0 commit comments

Comments
 (0)