@@ -24,6 +24,9 @@
 from einops import rearrange
 from .. import common_types, max_logging
 
+from . import quantizations
+
+
 Array = common_types.Array
 Mesh = common_types.Mesh
 DType = common_types.DType
@@ -36,6 +39,11 @@
 HEAD = common_types.HEAD
 D_KV = common_types.D_KV
 EMBED = common_types.EMBED
+Quant = quantizations.AqtQuantization
+
+
+def _maybe_aqt_einsum(quant: Quant):
+  return jnp.einsum if quant is None else quant.einsum()
 
 
 class AttentionOp(nn.Module):
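
The `_maybe_aqt_einsum` helper selects the einsum implementation once, so call sites stay uniform whether or not quantization is enabled. A minimal sketch of an assumed call site (not shown in this diff; the subscripts and variable names are illustrative):

```python
# Assumed call site: with a quant config, the query/key contraction runs
# through AQT's quantized einsum; with quant=None it is plain jnp.einsum.
einsum = _maybe_aqt_einsum(self.quant)
attn_weights = einsum("btnh,bsnh->bnts", query, key)
```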
@@ -51,6 +59,7 @@ class AttentionOp(nn.Module):
   flash_min_seq_length: int = 4096
   flash_block_sizes: BlockSizes = None
   dtype: DType = jnp.float32
+  quant: Quant = None
 
   def setup(self):
     if self.attention_kernel == "cudnn_flash_te":
@@ -585,6 +594,7 @@ class FlaxAttention(nn.Module):
           jax mesh is required if attention is set to flash.
       dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
           Parameters `dtype`
+      quant (`AqtQuantization`, *optional*, defaults to None)
 
   """
 
@@ -605,6 +615,7 @@ class FlaxAttention(nn.Module):
   value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   out_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   precision: jax.lax.Precision = None
+  quant: Quant = None
 
   def setup(self):
 
@@ -624,10 +635,13 @@ def setup(self):
         split_head_dim=self.split_head_dim,
         flash_block_sizes=self.flash_block_sizes,
         dtype=self.dtype,
+        quant=self.quant,
     )
 
     qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "heads"))
-
+    dot_general_cls = None
+    if self.quant:
+      dot_general_cls = self.quant.dot_general_cls()
     self.query = nn.Dense(
         inner_dim,
         kernel_init=qkv_init_kernel,
@@ -636,6 +650,7 @@ def setup(self):
         param_dtype=self.weights_dtype,
         name="to_q",
         precision=self.precision,
+        dot_general_cls=dot_general_cls,
     )
 
     self.key = nn.Dense(
@@ -646,6 +661,7 @@ def setup(self):
         param_dtype=self.weights_dtype,
         name="to_k",
         precision=self.precision,
+        dot_general_cls=dot_general_cls,
     )
 
     self.value = nn.Dense(
@@ -656,6 +672,7 @@ def setup(self):
         param_dtype=self.weights_dtype,
         name="to_v",
         precision=self.precision,
+        dot_general_cls=dot_general_cls,
     )
 
     self.proj_attn = nn.Dense(
@@ -665,6 +682,7 @@ def setup(self):
         param_dtype=self.weights_dtype,
         name="to_out_0",
         precision=self.precision,
+        dot_general_cls=dot_general_cls,
     )
     self.dropout_layer = nn.Dropout(rate=self.dropout)
 
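
Flax's `nn.Dense` accepts a `dot_general_cls` and calls it in place of `lax.dot_general`, which is how AQT slips quantization under the q/k/v/out projections above. A standalone sketch, assuming AQT's `aqt.jax.v2` Flax API (`AqtDotGeneral`, `config.fully_quantized`); it is independent of this diff:

```python
import functools

import flax.linen as nn
import jax
import jax.numpy as jnp
from aqt.jax.v2 import config as aqt_config
from aqt.jax.v2.flax import aqt_flax

# int8 forward/backward quantization config from the AQT library.
cfg = aqt_config.fully_quantized(fwd_bits=8, bwd_bits=8)

# nn.Dense instantiates dot_general_cls() and uses it for its matmul.
dense = nn.Dense(
    features=64,
    dot_general_cls=functools.partial(aqt_flax.AqtDotGeneral, cfg),
)

x = jnp.ones((4, 32))
params = dense.init(jax.random.PRNGKey(0), x)
y = dense.apply(params, x)  # projection now runs through AqtDotGeneral
```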
@@ -717,6 +735,7 @@ class FlaxBasicTransformerBlock(nn.Module):
           Overrides default block sizes for flash attention.
       mesh (`jax.sharding.mesh`, *optional*, defaults to `None`):
           jax mesh is required if attention is set to flash.
+      quant (`AqtQuantization`, *optional*, defaults to None)
   """
 
   dim: int
@@ -733,6 +752,7 @@ class FlaxBasicTransformerBlock(nn.Module):
   flash_block_sizes: BlockSizes = None
   mesh: jax.sharding.Mesh = None
   precision: jax.lax.Precision = None
+  quant: Quant = None
 
   def setup(self):
     # self attention (or cross_attention if only_cross_attention is True)
@@ -750,6 +770,7 @@ def setup(self):
         dtype=self.dtype,
         weights_dtype=self.weights_dtype,
         precision=self.precision,
+        quant=self.quant,
     )
     # cross attention
     self.attn2 = FlaxAttention(
@@ -766,6 +787,7 @@ def setup(self):
         dtype=self.dtype,
         weights_dtype=self.weights_dtype,
         precision=self.precision,
+        quant=self.quant,
     )
     self.ff = FlaxFeedForward(
         dim=self.dim, dropout=self.dropout, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision
@@ -838,6 +860,8 @@ class FlaxTransformer2DModel(nn.Module):
           Overrides default block sizes for flash attention.
       mesh (`jax.sharding.mesh`, *optional*, defaults to `None`):
           jax mesh is required if attention is set to flash.
+      quant (`AqtQuantization`, *optional*, defaults to None):
+          Configures AQT quantization (github.com/google/aqt).
   """
 
   in_channels: int
@@ -858,6 +882,7 @@ class FlaxTransformer2DModel(nn.Module):
   norm_num_groups: int = 32
   precision: jax.lax.Precision = None
   hidden_state_axis_names: AxisNames = (BATCH, LENGTH, D_KV)
+  quant: Quant = None
 
   def setup(self):
     self.norm = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype)
@@ -903,6 +928,7 @@ def setup(self):
             flash_block_sizes=self.flash_block_sizes,
             mesh=self.mesh,
             precision=self.precision,
+            quant=self.quant,
         )
         for _ in range(self.depth)
     ]
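
End to end, a single `quant` object is threaded through the module tree: `FlaxTransformer2DModel` passes it to each `FlaxBasicTransformerBlock`, which passes it to both `FlaxAttention` instances, which forward it to `AttentionOp` and use it to pick `dot_general_cls`. A hedged usage sketch: the `AqtQuantization` constructor arguments and import paths below are assumptions modeled on MaxText's quantizations module (only `dot_general_cls()` and `einsum()` appear in this diff), and the model arguments follow the diffusers-style signature:

```python
from aqt.jax.v2 import config as aqt_config
from maxdiffusion.models import quantizations
from maxdiffusion.models.attention_flax import FlaxTransformer2DModel

# Assumed constructor: AqtQuantization wrapping an int8 AQT DotGeneral config.
quant = quantizations.AqtQuantization(
    quant_dg=aqt_config.fully_quantized(fwd_bits=8, bwd_bits=None),
)

model = FlaxTransformer2DModel(
    in_channels=320,
    n_heads=8,
    d_head=40,
    quant=quant,  # quant=None keeps the unquantized path everywhere
)
```

Because every `quant` field defaults to `None`, the feature is strictly opt-in: existing configs keep the plain `jnp.einsum` / `lax.dot_general` path untouched.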