Skip to content

Commit 5b3654d

Browse files
Merge pull request #3319 from AI-Hypercomputer:amandaliang
PiperOrigin-RevId: 879178990
2 parents 8a1b34d + 5f60e5e commit 5b3654d

3 files changed

Lines changed: 19 additions & 12 deletions

File tree

src/maxtext/kernels/megablox/ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,14 @@ def gmm(
4444
weight_gather_axes: List[Tuple[str, int]] | None = None,
4545
input_buffer_count: tuple[int, int, int] = (2, 2, 2),
4646
combine_scopes: bool = False,
47+
# TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
48+
qwix_rule: qwix.QtRule | None = None,
4749
):
4850
"""Grouped matrix multiplication operation."""
4951
quantization_rule = None
5052
if use_qwix_quantization:
5153
# get_current_rule has to be called outside of the _gmm_fwd function.
52-
quantization_rule = qpl.get_current_rule("gmm")
54+
quantization_rule = qwix_rule if qwix_rule else qpl.get_current_rule("gmm")
5355
if quantization_rule and not isinstance(quantization_rule, qwix.QtRule):
5456
raise ValueError("Expect a QtRule for quantized training.")
5557
else:

src/maxtext/layers/quantizations.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,19 @@ def dot_general(self, *args, **kwargs):
640640
return nn.NANOOFp8DotGeneralOp(name=op_id)(*args, **kwargs)
641641

642642

643+
def get_fp8_full_qwix_rule(config: Config):
644+
return qwix.QtRule(
645+
module_path="decoder/.*layers.*",
646+
weight_qtype=jnp.float8_e4m3fn,
647+
act_qtype=jnp.float8_e4m3fn,
648+
bwd_qtype=jnp.float8_e5m2,
649+
weight_calibration_method=config.weight_quantization_calibration_method,
650+
act_calibration_method=config.act_quantization_calibration_method,
651+
bwd_calibration_method=config.bwd_quantization_calibration_method,
652+
op_names=("dot_general", "gmm", "ragged_dot"),
653+
)
654+
655+
643656
def get_quantization_rule(config: Config):
644657
match config.quantization:
645658
case "int8":
@@ -661,16 +674,7 @@ def get_quantization_rule(config: Config):
661674
op_names=("dot_general",),
662675
)
663676
case "fp8_full":
664-
return qwix.QtRule(
665-
module_path="decoder/.*layers.*",
666-
weight_qtype=jnp.float8_e4m3fn,
667-
act_qtype=jnp.float8_e4m3fn,
668-
bwd_qtype=jnp.float8_e5m2,
669-
weight_calibration_method=config.weight_quantization_calibration_method,
670-
act_calibration_method=config.act_quantization_calibration_method,
671-
bwd_calibration_method=config.bwd_quantization_calibration_method,
672-
op_names=("dot_general", "gmm", "ragged_dot"),
673-
)
677+
return get_fp8_full_qwix_rule(config)
674678
case "fp8_gpu":
675679
return qwix.QtRule(
676680
module_path="decoder/.*layers.*",
@@ -808,7 +812,7 @@ def generate_quantizer_set(self, postfix: str = ""):
808812
postfix=postfix,
809813
variable_collection=OVERWRITE_WITH_GRADIENT,
810814
quantization_checkpoint_name="quantization",
811-
fp8_recipe=fp8_recipe
815+
fp8_recipe=fp8_recipe,
812816
)
813817

814818
@nn.compact

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,7 @@ def gmm(
815815
weight_gather_axes=weight_gather_axes,
816816
input_buffer_count=input_buffer_count,
817817
combine_scopes=combine_scopes,
818+
qwix_rule=quantizations.get_fp8_full_qwix_rule(config),
818819
)
819820
else:
820821
output = tokamax.ragged_dot(

0 commit comments

Comments (0)