Commit 7b6ee9d

Added FA (flash attention) support for GPUs

1 parent: f56234e
2 files changed: 68 additions & 8 deletions

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 3 additions & 3 deletions
@@ -54,7 +54,7 @@ precision: "DEFAULT"
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'dot_product' # Supported attention: dot_product, flash
+attention: 'cudnn_flash_te' # Supported attention: dot_product, flash, cudnn_flash_te

 flash_block_sizes: {}
 # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
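
The new default targets NVIDIA GPUs, while 'flash' remains the TPU kernel and 'dot_product' the portable fallback. A minimal sketch of choosing the value per backend, assuming only the standard JAX device API (the helper itself is hypothetical and not part of this commit):

import jax

def pick_attention_kernel() -> str:
  # Hypothetical helper: choose a sensible default for the `attention`
  # config key based on the accelerator JAX sees.
  platform = jax.devices()[0].platform
  if platform == "gpu":
    return "cudnn_flash_te"  # Transformer Engine cuDNN flash attention
  if platform == "tpu":
    return "flash"           # TPU flash attention kernel
  return "dot_product"       # portable fallback

print(pick_attention_kernel())
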
@@ -193,11 +193,11 @@ learning_rate: 1.e-5
 scale_lr: False
 max_train_samples: -1
 # max_train_steps takes priority over num_train_epochs.
-max_train_steps: 1500
+max_train_steps: 50
 num_train_epochs: 1
 seed: 0
 output_dir: 'sdxl-model-finetuned'
-per_device_batch_size: 1
+per_device_batch_size: 8

 warmup_steps_fraction: 0.1
 learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
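
The training hyperparameters also switch to a short, larger-batch run. Under the usual data-parallel reading (global batch = per_device_batch_size * number of devices; the trainer's exact batching is not shown in this diff), the sample count works out as below. The device count and the data-parallel assumption are illustrative only:

# Back-of-the-envelope check under a data-parallel assumption.
per_device_batch_size = 8
max_train_steps = 50
num_devices = 8  # e.g. one 8-GPU node; adjust to your topology

global_batch = per_device_batch_size * num_devices
samples_seen = global_batch * max_train_steps
print(global_batch, samples_seen)  # 64 3200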

src/maxdiffusion/models/attention_flax.py

Lines changed: 65 additions & 5 deletions
@@ -52,6 +52,24 @@ class AttentionOp(nn.Module):
   flash_block_sizes: BlockSizes = None
   dtype: DType = jnp.float32

+  def setup(self):
+    if self.attention_kernel == "cudnn_flash_te":
+      from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
+      self.dpa_layer = DotProductAttention(
+          head_dim=self.dim_head,
+          num_attention_heads=self.heads,
+          num_gqa_groups=self.heads,
+          attn_mask_type="no_mask",  # 'no_mask', 'padding', 'causal', or 'padding_causal'
+          attn_bias_type="NO_BIAS",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+          # attention_dropout=self.dropout_rate,
+          dropout_rng_name="aqt",
+          dtype=self.dtype,
+          # float32_logits=self.float32_logits,
+          qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
+          scale_factor=self.scale,
+          transpose_batch_sequence=False,
+      )
+
   def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None:
     """Check attention inputs."""
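
Note that attention_dropout is left commented out, so the layer runs with Transformer Engine's default dropout setting; dropout_rng_name="aqt" only matters once dropout is enabled, at which point the caller must supply an RNG stream under that name at apply time. A minimal Flax sketch of that mechanism using a generic nn.Dropout (illustrative only, not the TE layer or maxdiffusion code):

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyDrop(nn.Module):
  # Generic illustration of a named dropout RNG collection ("aqt" here).
  @nn.compact
  def __call__(self, x, deterministic: bool):
    return nn.Dropout(rate=0.1, rng_collection="aqt")(x, deterministic=deterministic)

m = TinyDrop()
x = jnp.ones((2, 4))
variables = m.init(jax.random.PRNGKey(0), x, deterministic=True)
# With dropout active, the "aqt" stream must be provided explicitly.
y = m.apply(variables, x, deterministic=False, rngs={"aqt": jax.random.PRNGKey(1)})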

@@ -64,16 +82,22 @@ def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None
   def apply_attention(self, query: Array, key: Array, value: Array):
     """Routes to different attention kernels."""
     self.check_attention_inputs(query, key, value)
-    can_use_flash_attention = (
-        query.shape[1] >= self.flash_min_seq_length
-        and key.shape[1] >= self.flash_min_seq_length
-        and value.shape[1] >= self.flash_min_seq_length
-    )
+
+    if self.attention_kernel == "flash":
+      can_use_flash_attention = (
+          query.shape[1] >= self.flash_min_seq_length
+          and key.shape[1] >= self.flash_min_seq_length
+          and value.shape[1] >= self.flash_min_seq_length
+      )
+    else:
+      can_use_flash_attention = True

     if self.attention_kernel == "dot_product" or self.use_memory_efficient_attention or not can_use_flash_attention:
       return self.apply_attention_dot(query, key, value)
     elif self.attention_kernel == "flash":
       return self.tpu_flash_attention(query, key * self.scale, value)
+    elif self.attention_kernel == "cudnn_flash_te":
+      return self.cudnn_flash_attention(query, key, value)
     else:
       raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.")
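
With this change the flash_min_seq_length gate applies only to the TPU 'flash' kernel; 'cudnn_flash_te' never falls back to dot-product attention because of short sequences. A standalone restatement of the dispatch with module state replaced by plain arguments (illustrative only; the 4096 threshold is a made-up example value):

def select_kernel(attention_kernel: str, seq_len: int,
                  flash_min_seq_length: int = 4096,
                  use_memory_efficient_attention: bool = False) -> str:
  # Mirrors the routing above, returning the name of the method that would run.
  if attention_kernel == "flash":
    can_use_flash_attention = seq_len >= flash_min_seq_length
  else:
    can_use_flash_attention = True

  if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention:
    return "apply_attention_dot"
  elif attention_kernel == "flash":
    return "tpu_flash_attention"
  elif attention_kernel == "cudnn_flash_te":
    return "cudnn_flash_attention"
  raise ValueError(f"Unexpected attention kernel {attention_kernel=}.")

print(select_kernel("flash", seq_len=1024))           # apply_attention_dot (falls back)
print(select_kernel("cudnn_flash_te", seq_len=1024))  # cudnn_flash_attention (no gating)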

@@ -132,6 +156,32 @@ def wrap_flash_attention(query, key, value):

     return x

+  def cudnn_flash_attention(
+      self,
+      query: Array,
+      key: Array,
+      value: Array,
+  ) -> Array:
+    """CUDNN Flash Attention with Transformer Engine.
+    1. Stable API, supports GQA
+    2. Supports head_dim till 128; head_dim=256 support will be added soon
+    """
+    # These imports are only meant to work in a GPU build.
+    # copied from tpu_flash_attention
+    query = self.reshape_data_for_cudnn_flash(query)
+    key = self.reshape_data_for_cudnn_flash(key)
+    value = self.reshape_data_for_cudnn_flash(value)
+
+    cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV)
+    axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names)
+
+    query = nn.with_logical_constraint(query, axis_names)
+    key = nn.with_logical_constraint(key, axis_names)
+    value = nn.with_logical_constraint(value, axis_names)
+
+    out = self.dpa_layer(query, key, value, mask=None)
+    return self.reshape_data_from_cudnn_flash(out)
+
   def apply_attention_dot(self, query: Array, key: Array, value: Array):
     """Apply Attention."""
     if self.split_head_dim:
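
The new path reuses the same logical-axis sharding machinery as the TPU kernel: logical names are translated into a mesh PartitionSpec and pinned onto q/k/v before the TE layer runs. A minimal sketch of that translation with made-up axis names and rules (maxdiffusion defines its own BATCH, LENGTH, HEAD, D_KV constants and mesh rules elsewhere):

import flax.linen as nn

# Hypothetical logical axis names and sharding rules, for illustration only.
logical_axes = ("activation_batch", "activation_length", "activation_heads", "activation_kv")
rules = (
    ("activation_batch", "data"),
    ("activation_length", None),
    ("activation_heads", None),
    ("activation_kv", None),
)

with nn.logical_axis_rules(rules):
  spec = nn.logical_to_mesh_axes(logical_axes)
print(spec)  # e.g. PartitionSpec('data', None, None, None)
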
@@ -209,6 +259,16 @@ def reshape_batch_dim_to_heads(self, tensor):
     tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
     return tensor

+  def reshape_data_for_cudnn_flash(self, tensor):
+    # reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format)
+    batch, seq, heads_and_dim_head = tensor.shape
+    tensor = tensor.reshape(batch, seq, self.heads, heads_and_dim_head // self.heads)
+    return tensor
+
+  def reshape_data_from_cudnn_flash(self, tensor):
+    # reshapes from [b, s, h, d] back to [b, s, h * d]
+    return tensor.reshape(tensor.shape[0], tensor.shape[1], -1)
+
   def reshape_data_for_flash(self, tensor):
     # reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format)
     batch, seq, heads_and_dim_head = tensor.shape
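
Unlike the TPU path, which permutes tensors to [b, h, s, d], the cuDNN path keeps the [b, s, h, d] layout that the DotProductAttention layer was configured for (qkv_layout="BSHD_BSHD_BSHD"). A quick shape check of the round trip, using standalone jnp functions that mirror the two methods (the 24-head, 128-dim shapes are example values):

import jax.numpy as jnp

def to_cudnn_flash(tensor, heads):
  # [b, s, h * d] -> [b, s, h, d]
  b, s, hd = tensor.shape
  return tensor.reshape(b, s, heads, hd // heads)

def from_cudnn_flash(tensor):
  # [b, s, h, d] -> [b, s, h * d]
  return tensor.reshape(tensor.shape[0], tensor.shape[1], -1)

x = jnp.zeros((2, 1024, 24 * 128))  # batch, seq, heads * head_dim
y = to_cudnn_flash(x, heads=24)
print(y.shape)                      # (2, 1024, 24, 128)
print(from_cudnn_flash(y).shape)    # (2, 1024, 3072)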
