Skip to content

Commit 238a410

Browse files
quantized ragged dot maxtext integration
PiperOrigin-RevId: 836301101
1 parent 9dbec18 commit 238a410

3 files changed

Lines changed: 56 additions & 41 deletions

File tree

src/MaxText/layers/moe.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import enum
1919
import functools
2020
import math
21+
import random
2122
from typing import Iterable, Optional, Tuple, Union
2223

2324
from aqt.jax.v2 import aqt_tensor as aqt
@@ -860,7 +861,14 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments):
860861
if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1:
861862
raise ValueError("Unsupported usecase for ragged_dot with quantized kernel.")
862863
rhs_inputs = kernel.qvalue
863-
with set_xla_metadata(ragged_dot_tiling=",".join([str(t) for t in tiling])):
864+
if self.config.use_qwix_quantization:
865+
# Use full contraction for QWIX quantization to allow quantization
866+
# fusion (max reduce over contracting dimension).
867+
tiling = (tiling[0], k, tiling[2])
868+
with set_xla_metadata(
869+
ragged_dot_tiling=",".join([str(t) for t in tiling]),
870+
mosaic_fusion_group=f"{random.randint(0, 1000000000)}",
871+
):
864872
output = jax.lax.ragged_dot(
865873
lhs=inputs,
866874
rhs=rhs_inputs,

src/MaxText/layers/quantizations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ def get_quantization_rule(config: Config):
672672
weight_calibration_method=config.weight_quantization_calibration_method,
673673
act_calibration_method=config.act_quantization_calibration_method,
674674
bwd_calibration_method=config.bwd_quantization_calibration_method,
675-
op_names=("dot_general", "gmm"),
675+
op_names=("dot_general", "gmm", "ragged_dot"),
676676
)
677677
case "fp8_gpu":
678678
return qwix.QtRule(

tests/train_using_ragged_dot_smoke_test.py

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,54 +15,61 @@
1515
"""Smoke test for MoE using ragged_dot."""
1616

1717
import os
18-
import unittest
19-
from tempfile import gettempdir
18+
import tempfile
2019

2120
from absl.testing import absltest
21+
from absl.testing import parameterized
22+
from MaxText import globals as maxtext_globals
23+
from MaxText import train
2224

23-
from MaxText.globals import MAXTEXT_PKG_DIR
24-
from MaxText.train import main as train_main
25+
train_main = train.main
26+
MAXTEXT_PKG_DIR = maxtext_globals.MAXTEXT_PKG_DIR
27+
gettempdir = tempfile.gettempdir
2528

2629

27-
class Train(unittest.TestCase):
30+
class Train(parameterized.TestCase):
2831
"""Smoke test for MoE using ragged_dot in G3 only."""
2932

30-
def test_tiny_config(self):
33+
@parameterized.named_parameters(
34+
{"testcase_name": "not_quantized", "quantization": ""},
35+
{"testcase_name": "fp8_full", "quantization": "fp8_full"},
36+
)
37+
def test_tiny_config(self, quantization: str):
3138
test_tmpdir = os.environ.get("TEST_TMPDIR", gettempdir())
3239
outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", test_tmpdir)
33-
train_main(
34-
[
35-
None,
36-
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
37-
f"base_output_directory={test_tmpdir}",
38-
"run_name=ragged_dot_smoke_test",
39-
"base_emb_dim=128",
40-
"base_num_query_heads=4",
41-
"base_num_kv_heads=4",
42-
"base_mlp_dim=128",
43-
"base_moe_mlp_dim=128",
44-
"base_num_decoder_layers=8",
45-
"head_dim=128",
46-
# TODO(b/441100085): When changing the decoder_block we might
47-
# need to adjust the tiling.
48-
"decoder_block=deepseek",
49-
"attention_type=mla",
50-
"num_experts=2",
51-
# Enable sparse_matmul.
52-
"sparse_matmul=True",
53-
# Enable ragged_dot.
54-
"megablox=False",
55-
"per_device_batch_size=2",
56-
"max_target_length=1024",
57-
"dataset_type=synthetic",
58-
"steps=10",
59-
"enable_checkpointing=False",
60-
"enable_goodput_recording=False",
61-
"enable_checkpoint_cloud_logger=False",
62-
"monitor_goodput=False",
63-
f"metrics_file={os.path.join(outputs_dir, 'metrics.json')}",
64-
]
65-
)
40+
train_main([
41+
None,
42+
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
43+
f"base_output_directory={test_tmpdir}",
44+
"run_name=ragged_dot_smoke_test",
45+
"base_emb_dim=128",
46+
"base_num_query_heads=4",
47+
"base_num_kv_heads=4",
48+
"base_mlp_dim=128",
49+
"base_moe_mlp_dim=128",
50+
"base_num_decoder_layers=8",
51+
"head_dim=128",
52+
# TODO(b/441100085): When changing the decoder_block we might
53+
# need to adjust the tiling.
54+
"decoder_block=deepseek",
55+
"attention_type=mla",
56+
"num_experts=2",
57+
# Enable sparse_matmul.
58+
"sparse_matmul=True",
59+
# Enable ragged_dot.
60+
"megablox=False",
61+
f'quantization="{quantization}"',
62+
"use_qwix_quantization=True",
63+
"per_device_batch_size=2",
64+
"max_target_length=1024",
65+
"dataset_type=synthetic",
66+
"steps=10",
67+
"enable_checkpointing=False",
68+
"enable_goodput_recording=False",
69+
"enable_checkpoint_cloud_logger=False",
70+
"monitor_goodput=False",
71+
f"metrics_file={os.path.join(outputs_dir, 'metrics.json')}",
72+
])
6673

6774

6875
if __name__ == "__main__":

0 commit comments

Comments (0)