Merged
35 commits
bc6cd42
add support for flux vae. ~ wip
jfacevedo-google Jan 14, 2025
5f56257
test for flux vae both encoding and decoding.
jfacevedo-google Jan 14, 2025
c7829d1
add clip text encoder test
jfacevedo-google Jan 15, 2025
572f20d
remove transformers inside maxdiffusion, add transformers dependency.…
jfacevedo-google Jan 22, 2025
ff04543
add double block to flux
jfacevedo-google Jan 22, 2025
8a0ede4
forward pass for single double block.
jfacevedo-google Jan 22, 2025
9fe42ba
trying to use scan.
jfacevedo-google Jan 23, 2025
7e79e05
add single stream block
jfacevedo-google Jan 24, 2025
6641dda
finish transformer
jfacevedo-google Jan 29, 2025
d37a278
convert pt weights to flax and load transformer state.
jfacevedo-google Jan 30, 2025
bb91e8e
apply fsdp sharding, do one forward pass in the transformer.
jfacevedo-google Jan 30, 2025
dfe1089
wip - generate fn
jfacevedo-google Jan 30, 2025
cbc7723
working loop, bad generation
jfacevedo-google Jan 30, 2025
ac14a4b
e2e, encoder offloading.
jfacevedo-google Jan 30, 2025
1c8ed7b
support both dev and schnell loading. Images still incorrect.
jfacevedo-google Feb 1, 2025
c8196ed
flux schnell working
jfacevedo-google Feb 3, 2025
1f1475d
removed unused code.
jfacevedo-google Feb 3, 2025
b49695a
support dev
jfacevedo-google Feb 3, 2025
04377df
add sentencepiece requirement
jfacevedo-google Feb 4, 2025
f6c25e4
fix repeated double and single blocks.
jfacevedo-google Feb 4, 2025
ff24ee1
optimized flash block sizes for trillium.
jfacevedo-google Feb 4, 2025
18250c5
clean up code and lint
jfacevedo-google Feb 4, 2025
303e82a
fix sdxl generate smoke tests.
jfacevedo-google Feb 5, 2025
5df1f3c
fix rest of unit tests.
jfacevedo-google Feb 5, 2025
1ec459d
update readme and some dependencies.
entrpn Feb 5, 2025
1f28cb5
remove unused dependencies.
entrpn Feb 5, 2025
ff16ba6
initial lora implementation for flux
jfacevedo-google Feb 6, 2025
719e6db
adding another format lora support.
jfacevedo-google Feb 12, 2025
91d7f5c
Support other format loras. update readme. Run code_style.
jfacevedo-google Feb 13, 2025
1e01c67
ruff
jfacevedo-google Feb 13, 2025
19e1b8a
fix typo in readme.
entrpn Feb 5, 2025
f05a7be
Added FA support for GPUs
ksikiric Feb 13, 2025
3141d69
ruff and code_style
ksikiric Feb 18, 2025
2072f7e
fixed final comments
ksikiric Feb 20, 2025
56771dd
Correcting small misstake due to missunderstanding
ksikiric Feb 20, 2025
47 changes: 35 additions & 12 deletions README.md
@@ -44,18 +44,23 @@ MaxDiffusion supports

# Table of Contents

* [Getting Started](#getting-started)
* [Training](#training)
* [Dreambooth](#dreambooth)
* [Inference](#inference)
* [Flux](#flux)
* [Flux LoRA](#flux-lora)
* [Hyper-SD XL LoRA](#hyper-sdxl-lora)
* [Load Multiple LoRA](#load-multiple-lora)
* [SDXL Lightning](#sdxl-lightning)
* [ControlNet](#controlnet)
* [Comparison To Alternatives](#comparison-to-alternatives)
* [Development](#development)
- [What's new?](#whats-new)
- [Overview](#overview)
- [Table of Contents](#table-of-contents)
- [Getting Started](#getting-started)
- [Getting Started:](#getting-started-1)
- [Training](#training)
- [Dreambooth](#dreambooth)
- [Inference](#inference)
- [Flux](#flux)
- [Flash Attention for GPU:](#flash-attention-for-gpu)
- [Hyper SDXL LoRA](#hyper-sdxl-lora)
- [Load Multiple LoRA](#load-multiple-lora)
- [SDXL Lightning](#sdxl-lightning)
- [ControlNet](#controlnet)
- [Getting Started: Multihost development](#getting-started-multihost-development)
- [Comparison to Alternatives](#comparison-to-alternatives)
- [Development](#development)

# Getting Started

@@ -171,6 +176,24 @@ To generate images, run the following command:
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
```

## Flash Attention for GPU:
Flash Attention for GPU is supported via TransformerEngine. Installation instructions:

```bash
cd maxdiffusion
pip install -U "jax[cuda12]"
pip install -r requirements.txt
pip install --upgrade torch torchvision
pip install "transformer_engine[jax]
pip install .
```

Now run the command:

```bash
NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu
```
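
If the installation worked, you can quickly confirm that the TransformerEngine JAX bindings are importable before launching a full run. The sketch below is illustrative and not part of the repo; it reuses the `NVTE_FUSED_ATTN` flag from the command above and the same `DotProductAttention` import that `AttentionOp` uses when `attention="cudnn_flash_te"`.

```python
# Minimal sanity check for the cudnn_flash_te path (illustrative, not in the repo).
import os

os.environ.setdefault("NVTE_FUSED_ATTN", "1")  # same flag used in the run command above

# Same import AttentionOp performs when attention_kernel == "cudnn_flash_te".
from transformer_engine.jax.flax.transformer import DotProductAttention

print("TransformerEngine JAX DotProductAttention import OK:", DotProductAttention.__name__)
```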

## Flux LoRA

Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know.
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/base_flux_dev.yml
@@ -54,7 +54,7 @@ precision: "DEFAULT"
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te

flash_block_sizes: {}
# Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/base_flux_schnell.yml
@@ -53,7 +53,7 @@ precision: "DEFAULT"
# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
flash_block_sizes: {
"block_q" : 256,
"block_kv_compute" : 256,
2 changes: 1 addition & 1 deletion src/maxdiffusion/generate_flux.py
@@ -77,7 +77,7 @@ def unpack(x: Array, height: int, width: int) -> Array:


def vae_decode(latents, vae, state, config):
img = unpack(x=latents, height=config.resolution, width=config.resolution)
img = unpack(x=latents.astype(jnp.float32), height=config.resolution, width=config.resolution)
img = img / vae.config.scaling_factor + vae.config.shift_factor
img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample
return img
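
The one functional change above is the upcast to float32 before unpacking and decoding. A minimal, standalone illustration of why that helps is below; it is not code from the repo, and the scaling/shift values are the published Flux VAE defaults, assumed here rather than read from this config.

```python
# Illustration (assumed values): doing the VAE scale/shift in bfloat16 visibly
# quantizes the result, while upcasting to float32 first keeps full precision.
import jax.numpy as jnp

scaling_factor, shift_factor = 0.3611, 0.1159  # assumed Flux VAE config values

latent = jnp.asarray(1.2345, dtype=jnp.bfloat16)  # stand-in for one denoised latent value
print(latent / scaling_factor + shift_factor)                      # coarse bfloat16 arithmetic
print(latent.astype(jnp.float32) / scaling_factor + shift_factor)  # full-precision arithmetic
```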
71 changes: 66 additions & 5 deletions src/maxdiffusion/models/attention_flax.py
@@ -52,6 +52,25 @@ class AttentionOp(nn.Module):
flash_block_sizes: BlockSizes = None
dtype: DType = jnp.float32

def setup(self):
if self.attention_kernel == "cudnn_flash_te":
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error

self.dpa_layer = DotProductAttention(
head_dim=self.dim_head,
num_attention_heads=self.heads,
num_gqa_groups=self.heads,
attn_mask_type="no_mask", # 'no_mask', 'padding', 'causal', or 'padding_causal'
attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
# attention_dropout=self.dropout_rate,
dropout_rng_name="aqt",
dtype=self.dtype,
# float32_logits=self.float32_logits,
qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
scale_factor=self.scale,
transpose_batch_sequence=False,
)

def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None:
"""Check attention inputs."""

@@ -64,16 +83,22 @@ def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None
def apply_attention(self, query: Array, key: Array, value: Array):
"""Routes to different attention kernels."""
self.check_attention_inputs(query, key, value)
can_use_flash_attention = (
query.shape[1] >= self.flash_min_seq_length
and key.shape[1] >= self.flash_min_seq_length
and value.shape[1] >= self.flash_min_seq_length
)

if self.attention_kernel == "flash":
can_use_flash_attention = (
query.shape[1] >= self.flash_min_seq_length
and key.shape[1] >= self.flash_min_seq_length
and value.shape[1] >= self.flash_min_seq_length
)
else:
can_use_flash_attention = True

if self.attention_kernel == "dot_product" or self.use_memory_efficient_attention or not can_use_flash_attention:
return self.apply_attention_dot(query, key, value)
elif self.attention_kernel == "flash":
return self.tpu_flash_attention(query, key * self.scale, value)
elif self.attention_kernel == "cudnn_flash_te":
return self.cudnn_flash_attention(query, key, value)
else:
raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.")

@@ -132,6 +157,32 @@ def wrap_flash_attention(query, key, value):

return x

def cudnn_flash_attention(
self,
query: Array,
key: Array,
value: Array,
) -> Array:
"""CUDNN Flash Attention with Transformer Engine.
1. Stable API, supports GQA
2. Supports head_dim till 128; head_dim=256 support will be added soon
"""
# These imports are only meant to work in a GPU build.
# copied from tpu_flash_attention
query = self.reshape_data_for_cudnn_flash(query)
key = self.reshape_data_for_cudnn_flash(key)
value = self.reshape_data_for_cudnn_flash(value)

cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV)
axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names)

query = nn.with_logical_constraint(query, axis_names)
key = nn.with_logical_constraint(key, axis_names)
value = nn.with_logical_constraint(value, axis_names)

out = self.dpa_layer(query, key, value, mask=None)
return self.reshape_data_from_cudnn_flash(out)

def apply_attention_dot(self, query: Array, key: Array, value: Array):
"""Apply Attention."""
if self.split_head_dim:
@@ -209,6 +260,16 @@ def reshape_batch_dim_to_heads(self, tensor):
tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor

def reshape_data_for_cudnn_flash(self, tensor):
# reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format)
batch, seq, heads_and_dim_head = tensor.shape
tensor = tensor.reshape(batch, seq, self.heads, heads_and_dim_head // self.heads)
return tensor

def reshape_data_from_cudnn_flash(self, tensor):
# reshapes from [b, s, h, d] back to [b, s, h * d]
return tensor.reshape(tensor.shape[0], tensor.shape[1], -1)

def reshape_data_for_flash(self, tensor):
# reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format)
batch, seq, heads_and_dim_head = tensor.shape
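
Taken together, the cudnn_flash_te path hands TransformerEngine three separate tensors in the BSHD layout (qkv_layout="BSHD_BSHD_BSHD") and flattens the heads back out afterwards. The sketch below is a standalone shape round-trip of the two new reshape helpers, with illustrative sizes (dim_head=128 reflects the "supports head_dim till 128" note in the docstring); it is not code from the repo.

```python
# Standalone shape check mirroring reshape_data_for_cudnn_flash /
# reshape_data_from_cudnn_flash (illustrative sizes, not from the repo).
import jax.numpy as jnp

batch, seq, heads, dim_head = 2, 512, 24, 128  # head_dim up to 128 per the TE note

x = jnp.zeros((batch, seq, heads * dim_head))         # [b, s, h * d] as produced upstream
bshd = x.reshape(batch, seq, heads, dim_head)         # -> [b, s, h, d], the layout TE expects
out = bshd.reshape(bshd.shape[0], bshd.shape[1], -1)  # -> back to [b, s, h * d] after attention

assert out.shape == x.shape
```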
@@ -147,7 +147,7 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None):
if hidden_states.dtype == jnp.float16:
hidden_states = jnp.clip(hidden_states, -65504, 65504)

return hidden_states, temb, image_rotary_emb
return hidden_states


class FluxTransformerBlock(nn.Module):
@@ -296,7 +296,7 @@ def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=
encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
if encoder_hidden_states.dtype == jnp.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return hidden_states, encoder_hidden_states, temb, image_rotary_emb
return hidden_states, encoder_hidden_states


@flax_register_to_config
@@ -504,7 +504,7 @@ def __call__(
image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed"))

for double_block in self.double_blocks:
hidden_states, encoder_hidden_states, temb, image_rotary_emb = double_block(
hidden_states, encoder_hidden_states = double_block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
@@ -513,9 +513,7 @@
hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1)
hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))
for single_block in self.single_blocks:
hidden_states, temb, image_rotary_emb = single_block(
hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
)
hidden_states = single_block(hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

hidden_states = self.norm_out(hidden_states, temb)