Skip to content

Commit 350026b

Browse files
committed
Added FA support for GPUs
1 parent 63bad83 commit 350026b

4 files changed

Lines changed: 72 additions & 8 deletions

File tree

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ MaxDiffusion supports
5353
- [Dreambooth](#dreambooth)
5454
- [Inference](#inference)
5555
- [Flux](#flux)
56+
- [Flash Attention for GPU:](#flash-attention-for-gpu)
5657
- [Hyper SDXL LoRA](#hyper-sdxl-lora)
5758
- [Load Multiple LoRA](#load-multiple-lora)
5859
- [SDXL Lightning](#sdxl-lightning)
@@ -175,6 +176,9 @@ To generate images, run the following command:
175176
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
176177
```
177178

179+
### Flash Attention for GPU:
180+
Flash Attention for GPU is supported via TransformerEngine. Make sure it is installed, then specify `attention=cudnn_flash_te` when running the above commands.
181+
178182
## Flux LoRA
179183

180184
Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know.

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ precision: "DEFAULT"
5454
# Set true to load weights from pytorch
5555
from_pt: True
5656
split_head_dim: True
57-
attention: 'flash' # Supported attention: dot_product, flash
57+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
5858

5959
flash_block_sizes: {}
6060
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
@@ -197,7 +197,7 @@ max_train_steps: 200
197197
num_train_epochs: 1
198198
seed: 0
199199
output_dir: 'sdxl-model-finetuned'
200-
per_device_batch_size: 1
200+
per_device_batch_size: 8
201201

202202
warmup_steps_fraction: 0.0
203203
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ precision: "DEFAULT"
5353
# Set true to load weights from pytorch
5454
from_pt: True
5555
split_head_dim: True
56-
attention: 'flash' # Supported attention: dot_product, flash
56+
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
5757
flash_block_sizes: {
5858
"block_q" : 256,
5959
"block_kv_compute" : 256,

src/maxdiffusion/models/attention_flax.py

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,24 @@ class AttentionOp(nn.Module):
5252
flash_block_sizes: BlockSizes = None
5353
dtype: DType = jnp.float32
5454

55+
def setup(self):
56+
if self.attention_kernel == "cudnn_flash_te":
57+
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error
58+
self.dpa_layer = DotProductAttention(
59+
head_dim=self.dim_head,
60+
num_attention_heads=self.heads,
61+
num_gqa_groups=self.heads,
62+
attn_mask_type="no_mask", # 'no_mask', 'padding', 'causal', or 'padding_causal'
63+
attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
64+
# attention_dropout=self.dropout_rate,
65+
dropout_rng_name="aqt",
66+
dtype=self.dtype,
67+
# float32_logits=self.float32_logits,
68+
qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
69+
scale_factor=self.scale,
70+
transpose_batch_sequence=False,
71+
)
72+
5573
def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None:
5674
"""Check attention inputs."""
5775

@@ -64,16 +82,22 @@ def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None
6482
def apply_attention(self, query: Array, key: Array, value: Array):
6583
"""Routes to different attention kernels."""
6684
self.check_attention_inputs(query, key, value)
67-
can_use_flash_attention = (
68-
query.shape[1] >= self.flash_min_seq_length
69-
and key.shape[1] >= self.flash_min_seq_length
70-
and value.shape[1] >= self.flash_min_seq_length
71-
)
85+
86+
if self.attention_kernel == "flash":
87+
can_use_flash_attention = (
88+
query.shape[1] >= self.flash_min_seq_length
89+
and key.shape[1] >= self.flash_min_seq_length
90+
and value.shape[1] >= self.flash_min_seq_length
91+
)
92+
else:
93+
can_use_flash_attention = True
7294

7395
if self.attention_kernel == "dot_product" or self.use_memory_efficient_attention or not can_use_flash_attention:
7496
return self.apply_attention_dot(query, key, value)
7597
elif self.attention_kernel == "flash":
7698
return self.tpu_flash_attention(query, key * self.scale, value)
99+
elif self.attention_kernel == "cudnn_flash_te":
100+
return self.cudnn_flash_attention(query, key, value)
77101
else:
78102
raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.")
79103

@@ -132,6 +156,32 @@ def wrap_flash_attention(query, key, value):
132156

133157
return x
134158

159+
def cudnn_flash_attention(
160+
self,
161+
query: Array,
162+
key: Array,
163+
value: Array,
164+
) -> Array:
165+
"""CUDNN Flash Attention with Transformer Engine.
166+
1. Stable API, supports GQA
167+
2. Supports head_dim till 128; head_dim=256 support will be added soon
168+
"""
169+
# These imports are only meant to work in a GPU build.
170+
# copied from tpu_flash_attention
171+
query = self.reshape_data_for_cudnn_flash(query)
172+
key = self.reshape_data_for_cudnn_flash(key)
173+
value = self.reshape_data_for_cudnn_flash(value)
174+
175+
cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV)
176+
axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names)
177+
178+
query = nn.with_logical_constraint(query, axis_names)
179+
key = nn.with_logical_constraint(key, axis_names)
180+
value = nn.with_logical_constraint(value, axis_names)
181+
182+
out = self.dpa_layer(query, key, value, mask=None)
183+
return self.reshape_data_from_cudnn_flash(out)
184+
135185
def apply_attention_dot(self, query: Array, key: Array, value: Array):
136186
"""Apply Attention."""
137187
if self.split_head_dim:
@@ -209,6 +259,16 @@ def reshape_batch_dim_to_heads(self, tensor):
209259
tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
210260
return tensor
211261

262+
def reshape_data_for_cudnn_flash(self, tensor):
263+
# reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format)
264+
batch, seq, heads_and_dim_head = tensor.shape
265+
tensor = tensor.reshape(batch, seq, self.heads, heads_and_dim_head // self.heads)
266+
return tensor
267+
268+
def reshape_data_from_cudnn_flash(self, tensor):
269+
# reshapes from [b, s, h, d] back to [b, s, h * d]
270+
return tensor.reshape(tensor.shape[0], tensor.shape[1], -1)
271+
212272
def reshape_data_for_flash(self, tensor):
213273
# reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format)
214274
batch, seq, heads_and_dim_head = tensor.shape

0 commit comments

Comments
 (0)