
Commit 1c9d4c1

ksikiric, jfacevedo-google, and entrpn authored
Flash attention for GPUs like in maxtext (#149)
* add support for flux vae. ~ wip
* test for flux vae both encoding and decoding.
* add clip text encoder test
* remove transformers inside maxdiffusion, add transformers dependency. Start creating generation code for flux.
* add double block to flux
* forward pass for single double block.
* trying to use scan.
* add single stream block
* finish transformer
* convert pt weights to flax and load transformer state.
* apply fsdp sharding, do one forward pass in the transformer.
* wip - generate fn
* working loop, bad generation
* e2e, encoder offloading.
* support both dev and schnell loading. Images still incorrect.
* flux schnell working
* removed unused code.
* support dev
* add sentencepiece requirement
* fix repeated double and single blocks.
* optimized flash block sizes for trillium.
* clean up code and lint
* fix sdxl generate smoke tests.
* fix rest of unit tests.
* update readme and some dependencies.
* remove unused dependencies.
* initial lora implementation for flux
* adding another format lora support.
* Support other format loras. update readme. Run code_style.
* ruff
* fix typo in readme.
* Added FA support for GPUs
* ruff and code_style
* fixed final comments
* Correcting small mistake due to misunderstanding

---------

Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Co-authored-by: Juan Acevedo <juancevedo@gmail.com>
1 parent 41e901c · commit 1c9d4c1

6 files changed: 108 additions & 26 deletions


README.md

Lines changed: 35 additions & 12 deletions
````diff
@@ -44,18 +44,23 @@ MaxDiffusion supports
 
 # Table of Contents
 
-* [Getting Started](#getting-started)
-* [Training](#training)
-* [Dreambooth](#dreambooth)
-* [Inference](#inference)
-* [Flux](#flux)
-* [Flux LoRA](#flux-lora)
-* [Hyper-SD XL LoRA](#hyper-sdxl-lora)
-* [Load Multiple LoRA](#load-multiple-lora)
-* [SDXL Lightning](#sdxl-lightning)
-* [ControlNet](#controlnet)
-* [Comparison To Alternatives](#comparison-to-alternatives)
-* [Development](#development)
+- [What's new?](#whats-new)
+- [Overview](#overview)
+- [Table of Contents](#table-of-contents)
+- [Getting Started](#getting-started)
+- [Getting Started:](#getting-started-1)
+- [Training](#training)
+- [Dreambooth](#dreambooth)
+- [Inference](#inference)
+- [Flux](#flux)
+- [Flash Attention for GPU:](#flash-attention-for-gpu)
+- [Hyper SDXL LoRA](#hyper-sdxl-lora)
+- [Load Multiple LoRA](#load-multiple-lora)
+- [SDXL Lightning](#sdxl-lightning)
+- [ControlNet](#controlnet)
+- [Getting Started: Multihost development](#getting-started-multihost-development)
+- [Comparison to Alternatives](#comparison-to-alternatives)
+- [Development](#development)
 
 # Getting Started
 
@@ -171,6 +176,24 @@ To generate images, run the following command:
 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
 ```
 
+## Flash Attention for GPU:
+Flash Attention for GPU is supported via TransformerEngine. Installation instructions:
+
+```bash
+cd maxdiffusion
+pip install -U "jax[cuda12]"
+pip install -r requirements.txt
+pip install --upgrade torch torchvision
+pip install "transformer_engine[jax]"
+pip install .
+```
+
+Now run the command:
+
+```bash
+NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu
+```
+
 ## Flux LoRA
 
 Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know.
````
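Before launching the GPU command, a quick check that the Transformer Engine JAX path is importable can save a failed run. A minimal sketch; the import path matches the one used in `attention_flax.py` below:

```python
# Minimal sketch: verify Transformer Engine's JAX attention module is importable
# before launching generation with attention="cudnn_flash_te".
try:
    from transformer_engine.jax.flax.transformer import DotProductAttention  # noqa: F401
    print("Transformer Engine JAX detected; cudnn_flash_te is available.")
except ImportError as err:
    print(f"Transformer Engine is not installed for JAX: {err}")
```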

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -54,7 +54,7 @@ precision: "DEFAULT"
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash
+attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
 
 flash_block_sizes: {}
 # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
```
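The `attention` option above selects among three kernels. A condensed, illustrative sketch of how `AttentionOp.apply_attention` routes on it, paraphrased from the `attention_flax.py` diff further down (`route_attention` is a hypothetical helper, not part of the codebase, and the memory-efficient-attention check is omitted):

```python
# Hypothetical helper paraphrasing AttentionOp.apply_attention's routing logic
# from the attention_flax.py diff below; illustrative only.
def route_attention(kernel: str, seq_len: int, flash_min_seq_length: int) -> str:
  # Only the TPU "flash" kernel is gated on a minimum sequence length.
  can_use_flash = seq_len >= flash_min_seq_length if kernel == "flash" else True
  if kernel == "dot_product" or not can_use_flash:
    return "apply_attention_dot"    # plain dot-product attention (fallback)
  if kernel == "flash":
    return "tpu_flash_attention"    # Pallas flash kernel on TPU
  if kernel == "cudnn_flash_te":
    return "cudnn_flash_attention"  # Transformer Engine fused kernel on GPU
  raise ValueError(f"Unexpected attention kernel {kernel!r}.")
```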

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -53,7 +53,7 @@ precision: "DEFAULT"
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash
+attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
 flash_block_sizes: {
   "block_q" : 256,
   "block_kv_compute" : 256,
```

src/maxdiffusion/generate_flux.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -77,7 +77,7 @@ def unpack(x: Array, height: int, width: int) -> Array:
 
 
 def vae_decode(latents, vae, state, config):
-  img = unpack(x=latents, height=config.resolution, width=config.resolution)
+  img = unpack(x=latents.astype(jnp.float32), height=config.resolution, width=config.resolution)
   img = img / vae.config.scaling_factor + vae.config.shift_factor
   img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample
   return img
```
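The change up-casts the packed latents to float32 before unpacking, so the VAE's scale-and-shift and decode run in full precision. A standalone sketch of the arithmetic with placeholder constants and an assumed latent shape (the real values come from `vae.config`):

```python
# Standalone sketch of the decode-time cast and denormalization. The constants
# and the latent shape are illustrative placeholders; vae_decode reads the real
# values from vae.config.
import jax.numpy as jnp

scaling_factor = 0.3611  # placeholder for vae.config.scaling_factor
shift_factor = 0.1159    # placeholder for vae.config.shift_factor

latents = jnp.ones((1, 4096, 64), dtype=jnp.bfloat16)  # packed latents, low precision
latents = latents.astype(jnp.float32)                  # up-cast before unpack/decode
img = latents / scaling_factor + shift_factor          # matches vae_decode's arithmetic
print(img.dtype)  # float32
```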

src/maxdiffusion/models/attention_flax.py

Lines changed: 66 additions & 5 deletions
```diff
@@ -52,6 +52,25 @@ class AttentionOp(nn.Module):
   flash_block_sizes: BlockSizes = None
   dtype: DType = jnp.float32
 
+  def setup(self):
+    if self.attention_kernel == "cudnn_flash_te":
+      from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
+
+      self.dpa_layer = DotProductAttention(
+          head_dim=self.dim_head,
+          num_attention_heads=self.heads,
+          num_gqa_groups=self.heads,
+          attn_mask_type="no_mask",  # 'no_mask', 'padding', 'causal', or 'padding_causal'
+          attn_bias_type="NO_BIAS",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+          # attention_dropout=self.dropout_rate,
+          dropout_rng_name="aqt",
+          dtype=self.dtype,
+          # float32_logits=self.float32_logits,
+          qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
+          scale_factor=self.scale,
+          transpose_batch_sequence=False,
+      )
+
   def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None:
     """Check attention inputs."""
 
@@ -64,16 +83,22 @@ def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None
   def apply_attention(self, query: Array, key: Array, value: Array):
     """Routes to different attention kernels."""
     self.check_attention_inputs(query, key, value)
-    can_use_flash_attention = (
-        query.shape[1] >= self.flash_min_seq_length
-        and key.shape[1] >= self.flash_min_seq_length
-        and value.shape[1] >= self.flash_min_seq_length
-    )
+
+    if self.attention_kernel == "flash":
+      can_use_flash_attention = (
+          query.shape[1] >= self.flash_min_seq_length
+          and key.shape[1] >= self.flash_min_seq_length
+          and value.shape[1] >= self.flash_min_seq_length
+      )
+    else:
+      can_use_flash_attention = True
 
     if self.attention_kernel == "dot_product" or self.use_memory_efficient_attention or not can_use_flash_attention:
       return self.apply_attention_dot(query, key, value)
     elif self.attention_kernel == "flash":
       return self.tpu_flash_attention(query, key * self.scale, value)
+    elif self.attention_kernel == "cudnn_flash_te":
+      return self.cudnn_flash_attention(query, key, value)
     else:
       raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.")
 
@@ -132,6 +157,32 @@ def wrap_flash_attention(query, key, value):
 
     return x
 
+  def cudnn_flash_attention(
+      self,
+      query: Array,
+      key: Array,
+      value: Array,
+  ) -> Array:
+    """CUDNN Flash Attention with Transformer Engine.
+    1. Stable API, supports GQA
+    2. Supports head_dim up to 128; head_dim=256 support will be added soon
+    """
+    # These imports are only meant to work in a GPU build.
+    # copied from tpu_flash_attention
+    query = self.reshape_data_for_cudnn_flash(query)
+    key = self.reshape_data_for_cudnn_flash(key)
+    value = self.reshape_data_for_cudnn_flash(value)
+
+    cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV)
+    axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names)
+
+    query = nn.with_logical_constraint(query, axis_names)
+    key = nn.with_logical_constraint(key, axis_names)
+    value = nn.with_logical_constraint(value, axis_names)
+
+    out = self.dpa_layer(query, key, value, mask=None)
+    return self.reshape_data_from_cudnn_flash(out)
+
 @@ -209,6 +260,16 @@ def reshape_batch_dim_to_heads(self, tensor):
     tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
     return tensor
 
+  def reshape_data_for_cudnn_flash(self, tensor):
+    # reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format)
+    batch, seq, heads_and_dim_head = tensor.shape
+    tensor = tensor.reshape(batch, seq, self.heads, heads_and_dim_head // self.heads)
+    return tensor
+
+  def reshape_data_from_cudnn_flash(self, tensor):
+    # reshapes from [b, s, h, d] back to [b, s, h * d]
+    return tensor.reshape(tensor.shape[0], tensor.shape[1], -1)
+
   def reshape_data_for_flash(self, tensor):
     # reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format)
     batch, seq, heads_and_dim_head = tensor.shape
```
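A standalone, shape-level sketch of the BSHD round-trip these two helpers perform, under assumed sizes (batch=2, seq=1024, heads=8, dim_head=64):

```python
# Standalone sketch of the cudnn_flash reshape round-trip with assumed sizes.
# TE's DotProductAttention consumes BSHD tensors, while the rest of AttentionOp
# works with merged [b, s, h * d] activations.
import jax.numpy as jnp

batch, seq, heads, dim_head = 2, 1024, 8, 64
x = jnp.zeros((batch, seq, heads * dim_head))  # [b, s, h * d] attention input

bshd = x.reshape(batch, seq, heads, dim_head)  # reshape_data_for_cudnn_flash
assert bshd.shape == (2, 1024, 8, 64)          # BSHD layout for the TE kernel

merged = bshd.reshape(bshd.shape[0], bshd.shape[1], -1)  # reshape_data_from_cudnn_flash
assert merged.shape == (2, 1024, 512)          # back to [b, s, h * d]
```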

src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py

Lines changed: 4 additions & 6 deletions
```diff
@@ -147,7 +147,7 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None):
     if hidden_states.dtype == jnp.float16:
       hidden_states = jnp.clip(hidden_states, -65504, 65504)
 
-    return hidden_states, temb, image_rotary_emb
+    return hidden_states
 
 
 class FluxTransformerBlock(nn.Module):
@@ -296,7 +296,7 @@ def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=
     encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
     if encoder_hidden_states.dtype == jnp.float16:
       encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
-    return hidden_states, encoder_hidden_states, temb, image_rotary_emb
+    return hidden_states, encoder_hidden_states
 
 
 @flax_register_to_config
@@ -504,7 +504,7 @@ def __call__(
     image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed"))
 
     for double_block in self.double_blocks:
-      hidden_states, encoder_hidden_states, temb, image_rotary_emb = double_block(
+      hidden_states, encoder_hidden_states = double_block(
          hidden_states=hidden_states,
          encoder_hidden_states=encoder_hidden_states,
          temb=temb,
@@ -513,9 +513,7 @@ def __call__(
     hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1)
     hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))
     for single_block in self.single_blocks:
-      hidden_states, temb, image_rotary_emb = single_block(
-          hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
-      )
+      hidden_states = single_block(hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb)
     hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
 
     hidden_states = self.norm_out(hidden_states, temb)
```
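This simplification works because `temb` and `image_rotary_emb` are loop-invariant: blocks read them but never modify them, so only the updated activations need to be threaded through the loop. An illustrative sketch of the pattern with stand-in names, not the real modules:

```python
# Illustrative pattern behind the refactor: loop-invariant inputs (temb,
# image_rotary_emb) are passed to every iteration but never returned; only the
# tensors a block actually updates are threaded through the loop.
def run_blocks(blocks, hidden_states, encoder_hidden_states, temb, rope):
  for block in blocks:
    # Each block returns just the updated activations, as in the diff above.
    hidden_states, encoder_hidden_states = block(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        temb=temb,
        image_rotary_emb=rope,
    )
  return hidden_states, encoder_hidden_states
```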
