Skip to content

Commit cf051eb

Browse files
shuningjin (Google-ML-Automation) authored and committed
No public description
COPYBARA_INTEGRATE_REVIEW=#3405 from AI-Hypercomputer:shuningjin-qwix1 17800bf PiperOrigin-RevId: 885270721
1 parent 060fcd4 commit cf051eb

3 files changed

Lines changed: 148 additions & 15 deletions

File tree

src/maxtext/configs/types.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2531,6 +2531,15 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25312531
self.use_grpo = True
25322532
else:
25332533
self.use_grpo = False
2534+
2535+
if self.use_batch_split_schedule:
2536+
if not (self.decoder_block == DecoderBlockType.DEEPSEEK and self.sparse_matmul and self.use_tokamax_gmm):
2537+
raise ValueError("Batch split only supports deepseek, with `sparse_matmul=True` and `use_tokamax_gmm=True`")
2538+
if self.quantization and not (self.use_qwix_quantization and self.quantization == "fp8_full"):
2539+
raise ValueError(
2540+
"Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`"
2541+
)
2542+
25342543
if self.opt_type == "muon" and self.decoder_block not in [
25352544
DecoderBlockType.DEEPSEEK,
25362545
DecoderBlockType.QWEN3,
@@ -2539,7 +2548,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25392548
]:
25402549
raise ValueError(
25412550
"Muon dimension numbers haven't been tested for this model. Run this command first: "
2542-
f"`python3 -m MaxText.muon_utils {self.model_name} True`"
2551+
f"`python3 -m maxtext.utils.muon_utils {self.model_name} True`"
25432552
)
25442553
if self.force_q_layout and not self.use_jax_splash:
25452554
raise ValueError("`force_q_layout` can only be true if `use_jax_splash` is also true.")

src/maxtext/layers/quantizations.py

Lines changed: 96 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import functools
1818
import json
1919
import re
20-
from typing import Tuple, Sequence
20+
from typing import Tuple, Sequence, Callable
2121
from dataclasses import dataclass
2222

2323
from aqt.jax.v2 import config as aqt_config
@@ -27,6 +27,7 @@
2727
from aqt.jax.v2 import calibration
2828

2929
import qwix
30+
from qwix._src.core import dot_general_qt
3031

3132
import jax
3233
import jax.numpy as jnp
@@ -194,6 +195,88 @@ def einsum(self, mesh_axes: Tuple[str, ...] = ()):
194195
return aqt_einsum
195196

196197

198+
@dataclass
class QwixQuantization:
  """Configures Qwix quantization github.com/google/qwix, for training only.

  Builds fp8 dot_general/einsum wrappers from a single DotGeneralQtConfig:
  e4m3 for the forward activations/weights, e5m2 for the backward gradients.
  """

  # Plain class attribute, not a dataclass field (no annotation on purpose):
  # external callers read `quant_mode` on the provider object.
  quant_mode = "train"  # needed by external call
  # Calibration method names forwarded verbatim to Qwix (e.g. "absmax").
  act_calibration_method: str = "absmax"
  weight_calibration_method: str = "absmax"
  bwd_calibration_method: str = "absmax"

  def _get_fp8_full_qwix_config(self) -> dot_general_qt.DotGeneralQtConfig:
    """Returns Qwix dot_general config for fp8_full quantization."""
    return dot_general_qt.DotGeneralQtConfig(
        lhs_qtype=jnp.float8_e4m3fn,  # activation
        rhs_qtype=jnp.float8_e4m3fn,  # weight
        dlhs_grad_qtype=jnp.float8_e5m2,  # activation gradient
        drhs_grad_qtype=jnp.float8_e5m2,  # weight gradient
        lhs_calibration_method=self.act_calibration_method,
        rhs_calibration_method=self.weight_calibration_method,
        dlhs_grad_calibration_method=self.bwd_calibration_method,
        drhs_grad_calibration_method=self.bwd_calibration_method,
        tile_size=None,  # no sub-channel tiling
    )

  def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns Qwix dot_general (a partial of QwixDotGeneral bound to the fp8 config).

    NOTE(review): `mesh_axes` is accepted for signature parity with the other
    quantization providers' `dot_general_cls` but is not used here — confirm.
    """
    return functools.partial(QwixDotGeneral, config=self._get_fp8_full_qwix_config())

  def einsum(self, mesh_axes: Tuple[str, ...] = ()):
    """Returns Qwix einsum. `mesh_axes` is unused (kept for signature parity)."""
    return QwixEinsum(config=self._get_fp8_full_qwix_config())
228+
229+
230+
class QwixDotGeneral(nn.Module):
  """A callable class for Qwix dot_general.

  Mirrors the `jax.lax.dot_general` call signature so instances can be dropped
  in wherever a dot_general callable is expected, while routing the actual
  computation through `dot_general_qt` with the stored quantization config.
  """

  # Quantization parameters (fp8 qtypes + calibration) applied to every call.
  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      lhs: jax.Array,
      rhs: jax.Array,
      dimension_numbers: jax.lax.DotDimensionNumbers,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      *,
      out_sharding=None,
  ) -> jax.Array:
    # NOTE(review): `precision`, `preferred_element_type`, and `out_sharding`
    # are accepted only for signature compatibility and are ignored by the
    # quantized path — confirm this is intended.
    return dot_general_qt.dot_general_qt(lhs, rhs, dimension_numbers, self.config)
248+
249+
250+
class QwixEinsum(nn.Module):
  """A callable class for Qwix einsum.

  Delegates to `jnp.einsum` but substitutes the contraction primitive with a
  Qwix-quantized dot_general built from `config`.
  """

  # Quantization parameters (fp8 qtypes + calibration) applied to every call.
  config: dot_general_qt.DotGeneralQtConfig

  @nn.compact
  def __call__(
      self,
      einsum_str: str,
      *operands: jax.Array,
      precision: jax.lax.PrecisionLike = None,
      preferred_element_type: jax.typing.DTypeLike | None = None,
      _dot_general: Callable[..., jax.Array] | None = None,
      out_sharding=None,
  ) -> jax.Array:
    # NOTE(review): the incoming `_dot_general` argument is accepted but
    # always overridden by the quantized one below — confirm intended.

    def custom_dot_general(*args, **kwargs):
      # Forward only (lhs, rhs, dimension_numbers); precision/dtype kwargs
      # from einsum's internals are intentionally dropped.
      return dot_general_qt.dot_general_qt(*args[:3], self.config)

    # NOTE(review): jit is disabled so einsum's `_dot_general` substitution
    # takes effect at trace time — presumably required for Qwix interception;
    # verify the performance implication is acceptable.
    with jax.disable_jit():
      return jnp.einsum(
          einsum_str,
          *operands,
          precision=precision,
          preferred_element_type=preferred_element_type,
          _dot_general=custom_dot_general,
          out_sharding=out_sharding,
      )
278+
279+
197280
@dataclass
198281
class Fp8Quantization(Quantization):
199282
"""Configures Fp8 quantization for NVIDIA GPUs"""
@@ -539,13 +622,20 @@ def get_quant_mode(quant_mode_str: str = "train"):
539622
return aqt_flax.QuantMode.SERVE
540623
elif quant_mode_str == "convert":
541624
return aqt_flax.QuantMode.CONVERT
542-
else:
543-
raise ValueError(f"Invalid quantization mode {quant_mode_str}.")
544-
return None
625+
raise ValueError(f"Invalid quantization mode {quant_mode_str}.")
545626

546627

547628
def configure_quantization(config: Config, quant_mode_str: str = "train"):
548629
"""Configure quantization based on user config and quant mode."""
630+
if config.use_batch_split_schedule and config.quantization:
631+
if not (config.use_qwix_quantization and config.quantization == "fp8_full"):
632+
raise ValueError("Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`")
633+
return QwixQuantization(
634+
weight_calibration_method=config.weight_quantization_calibration_method,
635+
act_calibration_method=config.act_quantization_calibration_method,
636+
bwd_calibration_method=config.bwd_quantization_calibration_method,
637+
)
638+
549639
if config.use_qwix_quantization:
550640
return None
551641
quant_cfg = _get_quant_config(config)
@@ -726,7 +816,8 @@ def get_qt_provider(config):
726816

727817
def maybe_quantize_model(model, config):
728818
"""Quantize the model if quantization is enabled."""
729-
if config.use_qwix_quantization:
819+
# Batch split is not using Qwix's interception feature but manual plumbing
820+
if config.use_qwix_quantization and not config.use_batch_split_schedule:
730821
quantization_provider = get_qt_provider(config)
731822
if quantization_provider:
732823
model = qwix.quantize_model(model, quantization_provider)

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ def batch_split_schedule(
406406
rope_factor=cfg.rope_factor,
407407
mscale=cfg.mscale,
408408
dtype=cfg.dtype,
409+
quant=quant,
409410
)
410411

411412
xs = moe(
@@ -418,6 +419,7 @@ def batch_split_schedule(
418419
expert_axis_name="expert",
419420
use_gather_mosaic_kernel=False,
420421
config=cfg,
422+
quant=quant,
421423
)
422424
return xs
423425

@@ -440,7 +442,21 @@ def with_data_parallel_constraint(x, mesh):
440442
return jax.lax.with_sharding_constraint(x, jax.NamedSharding(mesh, activation_pspec))
441443

442444

443-
def dot(x, y, quant=None, axes=1):
  """Computes the dot product of two arrays, optionally using quantization.

  Args:
    x: Left operand.
    y: Right operand.
    quant: Optional quantization provider exposing `dot_general_cls()`; when
      None the contraction falls through to `jnp.tensordot`.
    axes: `jnp.tensordot`-style axes spec — either an int (contract the last
      `axes` dims of `x` with the first `axes` dims of `y`) or a pair of axis
      collections `(x_axes, y_axes)`.

  Returns:
    The contracted array.
  """
  if quant is None:
    # Unquantized path.
    return jnp.tensordot(x, y, axes=axes)

  # Convert the tensordot-style `axes` into dot_general dimension_numbers.
  if isinstance(axes, int):
    x_contract = tuple(range(x.ndim - axes, x.ndim))
    y_contract = tuple(range(axes))
  else:
    x_contract, y_contract = axes
    # tensordot also accepts bare ints inside the pair (e.g. axes=(1, 0))
    # and negative axis indices; normalize both so dot_general gets
    # well-formed, non-negative axis tuples.
    if isinstance(x_contract, int):
      x_contract = (x_contract,)
    if isinstance(y_contract, int):
      y_contract = (y_contract,)
    x_contract = tuple(a % x.ndim for a in x_contract)
    y_contract = tuple(a % y.ndim for a in y_contract)
  dimension_numbers = ((x_contract, y_contract), ((), ()))

  # Instantiate and call the qwix dot_general (no batch dimensions).
  custom_dot = quant.dot_general_cls()()
  return custom_dot(lhs=x, rhs=y, dimension_numbers=dimension_numbers)
445461

446462

@@ -466,6 +482,7 @@ def mla_with_norms(
466482
rope_factor,
467483
mscale,
468484
dtype,
485+
quant,
469486
):
470487
"""Performs MLA with pre- and post-normalization."""
471488
(pre_attn_scale, post_attn_scale), attn_ws = weights
@@ -500,6 +517,7 @@ def fn(args):
500517
dtype=dtype,
501518
mscale=mscale,
502519
attention_op_fn=attn_op,
520+
quant=quant,
503521
),
504522
mesh,
505523
)
@@ -535,6 +553,7 @@ def mla(
535553
mscale,
536554
attention_op_fn,
537555
dtype,
556+
quant,
538557
):
539558
"""Performs MLA."""
540559
(
@@ -563,6 +582,7 @@ def mla(
563582
dtype=dtype,
564583
qk_nope_head_dim=qk_nope_head_dim,
565584
mscale=mscale,
585+
quant=quant,
566586
)
567587
query = jax.ad_checkpoint.checkpoint_name(query, "query_proj")
568588
key, value = kv_projection(
@@ -583,6 +603,7 @@ def mla(
583603
dtype=dtype,
584604
qk_nope_head_dim=qk_nope_head_dim,
585605
num_query_heads=num_query_heads,
606+
quant=quant,
586607
)
587608
key = jax.ad_checkpoint.checkpoint_name(key, "key_proj")
588609
value = jax.ad_checkpoint.checkpoint_name(value, "value_proj")
@@ -595,7 +616,7 @@ def mla(
595616
cached_values=[None, None],
596617
)
597618
out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
598-
out = dot(out, out_weights, axes=2)
619+
out = dot(out, out_weights, quant=quant, axes=2)
599620
out = jax.ad_checkpoint.checkpoint_name(out, "out_proj")
600621
return out
601622

@@ -618,6 +639,7 @@ def query_projection(
618639
rope_factor,
619640
dtype,
620641
mscale,
642+
quant,
621643
):
622644
"""Performs query projection."""
623645
# Set softmax scaling.
@@ -628,15 +650,15 @@ def query_projection(
628650
softmax_scale = softmax_scale * m * m
629651

630652
# LoRA path
631-
low_rank_q = dot(inputs_q, wq_a_weights)
653+
low_rank_q = dot(inputs_q, wq_a_weights, quant=quant)
632654
low_rank_q = rms_norm(
633655
low_rank_q,
634656
q_norm_scale_weights,
635657
epsilon=epsilon,
636658
dtype=dtype,
637659
)
638660
low_rank_q = jax.ad_checkpoint.checkpoint_name(low_rank_q, "mla_q")
639-
q = dot(low_rank_q, wq_b_weights)
661+
q = dot(low_rank_q, wq_b_weights, quant=quant)
640662

641663
# Split into non-positional and rotary parts.
642664
q_nope, q_pe = jnp.split(q, [qk_nope_head_dim], axis=-1)
@@ -675,9 +697,10 @@ def kv_projection(
675697
dtype,
676698
qk_nope_head_dim,
677699
num_query_heads,
700+
quant,
678701
):
679702
"""Performs KV projection."""
680-
low_rank = dot(inputs, wkv_a_weights)
703+
low_rank = dot(inputs, wkv_a_weights, quant=quant)
681704
low_rank_main, low_rank_rope = jnp.split(low_rank, [kv_lora_rank], axis=-1)
682705
low_rank_main = rms_norm(
683706
low_rank_main,
@@ -706,12 +729,13 @@ def kv_projection(
706729
wkv_b_weights,
707730
qk_nope_head_dim=qk_nope_head_dim,
708731
num_query_heads=num_query_heads,
732+
quant=quant,
709733
)
710734

711735

712-
def get_key_value(low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads):
736+
def get_key_value(low_rank_main, key_rope, wkv_b_weights, *, qk_nope_head_dim, num_query_heads, quant):
713737
"""Gets key and value from compressed KV latent vector and key rope."""
714-
kv_out = dot(low_rank_main, wkv_b_weights)
738+
kv_out = dot(low_rank_main, wkv_b_weights, quant=quant)
715739

716740
# Split kv_out into key_nope and value parts.
717741
key_nope, value = jnp.split(kv_out, [qk_nope_head_dim], axis=-1)
@@ -807,6 +831,7 @@ def moe(
807831
expert_axis_name,
808832
use_gather_mosaic_kernel,
809833
config,
834+
quant,
810835
):
811836
"""Performs dropless MoE with tensor/expert parallelism."""
812837
xs, ys = list(zip(*inputs))
@@ -821,6 +846,7 @@ def moe(
821846
expert_axis_name=expert_axis_name,
822847
use_gather_mosaic_kernel=use_gather_mosaic_kernel,
823848
config=config,
849+
quant=quant,
824850
),
825851
mesh,
826852
)
@@ -851,9 +877,10 @@ def expert_selection(
851877
num_experts,
852878
num_experts_per_tok,
853879
routed_scaling_factor,
880+
quant,
854881
):
855882
"""Selects experts for each token and calculates group sizes for each expert."""
856-
pre_bias_logits = jax.nn.sigmoid(dot(x, routing_kernel))
883+
pre_bias_logits = jax.nn.sigmoid(dot(x, routing_kernel, quant=quant))
857884
logits = pre_bias_logits + routing_bias
858885

859886
selected_experts, weights = expert_indices_and_weights(
@@ -1067,6 +1094,7 @@ def route_compute_unroute(
10671094
use_gather_mosaic_kernel,
10681095
config,
10691096
mesh,
1097+
quant,
10701098
):
10711099
"""Routes, processes, and unroutes activations."""
10721100
orig_shape = xs[0].shape
@@ -1078,7 +1106,9 @@ def route_compute_unroute(
10781106

10791107
def route_fn(inputs):
10801108
# Shared expert.
1081-
y = dot(jax.nn.silu(dot(inputs, shared_w0)) * dot(inputs, shared_w1), shared_wo)
1109+
y = dot(
1110+
jax.nn.silu(dot(inputs, shared_w0, quant=quant)) * dot(inputs, shared_w1, quant=quant), shared_wo, quant=quant
1111+
)
10821112

10831113
inputs = jnp.reshape(inputs, (-1, inputs.shape[-1]))
10841114
selected_experts, weights, group_sizes = expert_selection(
@@ -1088,6 +1118,7 @@ def route_fn(inputs):
10881118
num_experts=num_experts,
10891119
num_experts_per_tok=num_experts_per_tok,
10901120
routed_scaling_factor=routed_scaling_factor,
1121+
quant=quant,
10911122
)
10921123
x, selected_experts, weights, group_sizes = route(
10931124
inputs,
@@ -1140,6 +1171,7 @@ def process_activations(
11401171
expert_axis_name,
11411172
use_gather_mosaic_kernel,
11421173
config,
1174+
quant,
11431175
):
11441176
"""Processes activations, which are fully sharded on the batch axis, with tensor/expert sharded weights."""
11451177
activation_pspec = jax.sharding.PartitionSpec(
@@ -1164,6 +1196,7 @@ def process_activations(
11641196
use_gather_mosaic_kernel=use_gather_mosaic_kernel,
11651197
config=config,
11661198
mesh=mesh,
1199+
quant=quant,
11671200
),
11681201
mesh=mesh,
11691202
in_specs=(

0 commit comments

Comments (0)