Commit 3b4f4d5

AQT integration (#156)
Parent commit: 2228609

16 files changed

Lines changed: 373 additions & 4 deletions

requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -30,3 +30,4 @@ huggingface_hub==0.24.7
 transformers==4.48.1
 einops==0.8.0
 sentencepiece
+aqtp
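
Note: aqtp is the PyPI distribution of AQT (Accurate Quantized Training, github.com/google/aqt), the JAX library that supplies the quantized dot_general and einsum primitives the attention changes below plug into. A minimal import sketch, assuming the aqt.jax.v2 package layout this integration builds on (not code from this commit):

from aqt.jax.v2 import config as aqt_config  # quantization configs
from aqt.jax.v2.flax import aqt_flax         # Flax wrappers such as AqtDotGeneral / AqtEinsum

An AqtDotGeneral class (typically bound to a config) is the kind of object that ends up passed to flax.linen.Dense through its dot_general_cls field in the attention_flax.py changes below.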

src/maxdiffusion/configs/base14.yml

Lines changed: 4 additions & 0 deletions

@@ -216,3 +216,7 @@ prior_loss_weight: 1.0
 num_class_images: 100
 # If true, set dataset_save_location.
 cache_dreambooth_dataset: False
+quantization: ''
+# Shard the range finding operation for quantization. By default this is set to number of slices.
+quantization_local_shard_count: -1
+compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
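
The same three keys are added to each config below: quantization (an empty string leaves quantization off; a non-empty value such as 'int8' is what is expected to enable AQT), quantization_local_shard_count, and compile_topology_num_slices. As a hypothetical reading of the quantization_local_shard_count comment only, not the exact maxdiffusion implementation, the -1 sentinel would resolve like this:

def resolve_local_shard_count(quantization_local_shard_count: int, num_slices: int) -> int:
  # -1 means "shard the quantization range-finding pass across the number of slices".
  return num_slices if quantization_local_shard_count == -1 else quantization_local_shard_count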

src/maxdiffusion/configs/base21.yml

Lines changed: 4 additions & 0 deletions

@@ -217,3 +217,7 @@ prior_loss_weight: 1.0
 num_class_images: 100
 # If true, set dataset_save_location.
 cache_dreambooth_dataset: False
+quantization: ''
+# Shard the range finding operation for quantization. By default this is set to number of slices.
+quantization_local_shard_count: -1
+compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 5 additions & 0 deletions

@@ -231,3 +231,8 @@ prior_loss_weight: 1.0
 num_class_images: 100
 # If true, set dataset_save_location.
 cache_dreambooth_dataset: False
+
+quantization: ''
+# Shard the range finding operation for quantization. By default this is set to number of slices.
+quantization_local_shard_count: -1
+compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 5 additions & 0 deletions

@@ -260,3 +260,8 @@ controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
 controlnet_from_pt: True
 controlnet_conditioning_scale: 0.5
 controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
+quantization: ''
+# Shard the range finding operation for quantization. By default this is set to number of slices.
+quantization_local_shard_count: -1
+compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
+

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 4 additions & 0 deletions

@@ -268,3 +268,7 @@ controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
 controlnet_from_pt: True
 controlnet_conditioning_scale: 0.5
 controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
+quantization: ''
+# Shard the range finding operation for quantization. By default this is set to number of slices.
+quantization_local_shard_count: -1
+compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base_xl.yml

Lines changed: 6 additions & 0 deletions

@@ -233,3 +233,9 @@ controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
 controlnet_from_pt: True
 controlnet_conditioning_scale: 0.5
 controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
+enable_mllog: False
+
+quantization: ''
+# Shard the range finding operation for quantization. By default this is set to number of slices.
+quantization_local_shard_count: -1
+compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base_xl_lightning.yml

Lines changed: 4 additions & 0 deletions

@@ -185,3 +185,7 @@ lora_config: {
 # }

 enable_mllog: False
+quantization: ''
+# Shard the range finding operation for quantization. By default this is set to number of slices.
+quantization_local_shard_count: -1
+compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configuration_utils.py

Lines changed: 2 additions & 1 deletion

@@ -577,6 +577,7 @@ def to_json_saveable(value):
     config_dict.pop("mesh", None)
     config_dict.pop("precision", None)
     config_dict.pop("weights_dtype", None)
+    config_dict.pop("quant", None)

     return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

@@ -659,7 +660,7 @@ def init(self, *args, **kwargs):
       # ignore flax specific attributes
       if field.name in self._flax_internal_args:
         continue
-      if type(field.default) == dataclasses._MISSING_TYPE:
+      if type(field.default) == dataclasses._MISSING_TYPE:  # noqa: E721
         default_kwargs[field.name] = None
       else:
         default_kwargs[field.name] = getattr(self, field.name)
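
For context on the new noqa: pycodestyle's E721 warns against comparing types with ==, preferring identity or isinstance checks. The comparison here is an exact check against dataclasses' "no default" sentinel type, so the commit suppresses the warning rather than rewriting the line. A lint-clean equivalent (not what the commit does) checks the sentinel instance directly:

import dataclasses


@dataclasses.dataclass
class Example:
  required: int  # no default, so its field.default is dataclasses.MISSING
  optional: str = "hi"


for field in dataclasses.fields(Example):
  # Equivalent to type(field.default) == dataclasses._MISSING_TYPE, without tripping E721.
  print(field.name, field.default is dataclasses.MISSING)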

src/maxdiffusion/models/attention_flax.py

Lines changed: 30 additions & 1 deletion

@@ -24,6 +24,9 @@
 from einops import rearrange
 from .. import common_types, max_logging

+from . import quantizations
+
+
 Array = common_types.Array
 Mesh = common_types.Mesh
 DType = common_types.DType

@@ -36,6 +39,14 @@
 HEAD = common_types.HEAD
 D_KV = common_types.D_KV
 EMBED = common_types.EMBED
+Quant = quantizations.AqtQuantization
+
+
+Quant = quantizations.AqtQuantization
+
+
+def _maybe_aqt_einsum(quant: Quant):
+  return jnp.einsum if quant is None else quant.einsum()


 class AttentionOp(nn.Module):

@@ -51,6 +62,7 @@ class AttentionOp(nn.Module):
   flash_min_seq_length: int = 4096
   flash_block_sizes: BlockSizes = None
   dtype: DType = jnp.float32
+  quant: Quant = None

   def setup(self):
     if self.attention_kernel == "cudnn_flash_te":

@@ -585,6 +597,7 @@ class FlaxAttention(nn.Module):
           jax mesh is required if attention is set to flash.
       dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
           Parameters `dtype`
+      quant (`AqtQuantization`, *optional*, defaults to None)

   """

@@ -605,6 +618,7 @@ class FlaxAttention(nn.Module):
   value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   out_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
   precision: jax.lax.Precision = None
+  quant: Quant = None

   def setup(self):

@@ -624,10 +638,13 @@ def setup(self):
         split_head_dim=self.split_head_dim,
         flash_block_sizes=self.flash_block_sizes,
         dtype=self.dtype,
+        quant=self.quant,
     )

     qkv_init_kernel = nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "heads"))
-
+    dot_general_cls = None
+    if self.quant:
+      dot_general_cls = self.quant.dot_general_cls()
     self.query = nn.Dense(
         inner_dim,
         kernel_init=qkv_init_kernel,

@@ -636,6 +653,7 @@ def setup(self):
         param_dtype=self.weights_dtype,
         name="to_q",
         precision=self.precision,
+        dot_general_cls=dot_general_cls,
     )

     self.key = nn.Dense(

@@ -646,6 +664,7 @@ def setup(self):
         param_dtype=self.weights_dtype,
         name="to_k",
         precision=self.precision,
+        dot_general_cls=dot_general_cls,
     )

     self.value = nn.Dense(

@@ -656,6 +675,7 @@ def setup(self):
         param_dtype=self.weights_dtype,
         name="to_v",
         precision=self.precision,
+        dot_general_cls=dot_general_cls,
     )

     self.proj_attn = nn.Dense(

@@ -665,6 +685,7 @@ def setup(self):
         param_dtype=self.weights_dtype,
         name="to_out_0",
         precision=self.precision,
+        dot_general_cls=dot_general_cls,
     )
     self.dropout_layer = nn.Dropout(rate=self.dropout)

@@ -717,6 +738,7 @@ class FlaxBasicTransformerBlock(nn.Module):
           Overrides default block sizes for flash attention.
       mesh (`jax.sharding.mesh`, *optional*, defaults to `None`):
           jax mesh is required if attention is set to flash.
+      quant (`AqtQuantization`, *optional*, defaults to None)
   """

   dim: int

@@ -733,6 +755,7 @@ class FlaxBasicTransformerBlock(nn.Module):
   flash_block_sizes: BlockSizes = None
   mesh: jax.sharding.Mesh = None
   precision: jax.lax.Precision = None
+  quant: Quant = None

   def setup(self):
     # self attention (or cross_attention if only_cross_attention is True)

@@ -750,6 +773,7 @@ def setup(self):
         dtype=self.dtype,
         weights_dtype=self.weights_dtype,
         precision=self.precision,
+        quant=self.quant,
     )
     # cross attention
     self.attn2 = FlaxAttention(

@@ -766,6 +790,7 @@ def setup(self):
         dtype=self.dtype,
         weights_dtype=self.weights_dtype,
         precision=self.precision,
+        quant=self.quant,
     )
     self.ff = FlaxFeedForward(
         dim=self.dim, dropout=self.dropout, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision

@@ -838,6 +863,8 @@ class FlaxTransformer2DModel(nn.Module):
           Overrides default block sizes for flash attention.
       mesh (`jax.sharding.mesh`, *optional*, defaults to `None`):
           jax mesh is required if attention is set to flash.
+      quant (`AqtQuantization`, *optional*, defaults to None)
+          Configures AQT quantization github.com/google/aqt.
   """

   in_channels: int

@@ -858,6 +885,7 @@ class FlaxTransformer2DModel(nn.Module):
   norm_num_groups: int = 32
   precision: jax.lax.Precision = None
   hidden_state_axis_names: AxisNames = (BATCH, LENGTH, D_KV)
+  quant: Quant = (None,)

   def setup(self):
     self.norm = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype)

@@ -903,6 +931,7 @@ def setup(self):
             flash_block_sizes=self.flash_block_sizes,
             mesh=self.mesh,
             precision=self.precision,
+            quant=self.quant,
         )
         for _ in range(self.depth)
     ]
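
Taken together, the attention_flax.py changes follow one pattern: each module gains an optional quant attribute, and when it is set, quant.dot_general_cls() is handed to the nn.Dense projections (to_q, to_k, to_v, to_out_0) while _maybe_aqt_einsum swaps jnp.einsum for quant.einsum(). A minimal, self-contained sketch of that pattern; the quant object stands in for maxdiffusion's quantizations.AqtQuantization (which this diff imports but does not show), and the module and parameter names here are illustrative only:

import jax
import jax.numpy as jnp
import flax.linen as nn


def _maybe_aqt_einsum(quant):
  # Fall back to the stock einsum when quantization is disabled.
  return jnp.einsum if quant is None else quant.einsum()


class QuantizableProjection(nn.Module):
  features: int
  quant: object = None  # an AqtQuantization-like object, or None

  @nn.compact
  def __call__(self, x):
    # Mirror the to_q/to_k/to_v/to_out_0 changes: hand the AQT dot_general class
    # to nn.Dense only when a quantization config is present.
    dot_general_cls = self.quant.dot_general_cls() if self.quant else None
    return nn.Dense(self.features, use_bias=False, dot_general_cls=dot_general_cls)(x)


# Unquantized path works as-is:
params = QuantizableProjection(features=8).init(jax.random.PRNGKey(0), jnp.ones((1, 4)))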
