diff --git a/README.md b/README.md
index f9be0f3fc..09492f921 100644
--- a/README.md
+++ b/README.md
@@ -44,18 +44,23 @@ MaxDiffusion supports
 
 # Table of Contents
 
-* [Getting Started](#getting-started)
-  * [Training](#training)
-  * [Dreambooth](#dreambooth)
-  * [Inference](#inference)
-    * [Flux](#flux)
-    * [Flux LoRA](#flux-lora)
-    * [Hyper-SD XL LoRA](#hyper-sdxl-lora)
-    * [Load Multiple LoRA](#load-multiple-lora)
-    * [SDXL Lightning](#sdxl-lightning)
-    * [ControlNet](#controlnet)
-* [Comparison To Alternatives](#comparison-to-alternatives)
-* [Development](#development)
+- [What's new?](#whats-new)
+- [Overview](#overview)
+- [Table of Contents](#table-of-contents)
+- [Getting Started](#getting-started)
+  - [Getting Started:](#getting-started-1)
+  - [Training](#training)
+  - [Dreambooth](#dreambooth)
+  - [Inference](#inference)
+    - [Flux](#flux)
+      - [Flash Attention for GPU:](#flash-attention-for-gpu)
+    - [Flux LoRA](#flux-lora)
+    - [Hyper SDXL LoRA](#hyper-sdxl-lora)
+    - [Load Multiple LoRA](#load-multiple-lora)
+    - [SDXL Lightning](#sdxl-lightning)
+    - [ControlNet](#controlnet)
+  - [Getting Started: Multihost development](#getting-started-multihost-development)
+- [Comparison to Alternatives](#comparison-to-alternatives)
+- [Development](#development)
 
 # Getting Started
@@ -171,6 +176,24 @@ To generate images, run the following command:
 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
 ```
+
+## Flash Attention for GPU:
+Flash Attention for GPU is supported via TransformerEngine. Installation instructions:
+
+```bash
+cd maxdiffusion
+pip install -U "jax[cuda12]"
+pip install -r requirements.txt
+pip install --upgrade torch torchvision
+pip install "transformer_engine[jax]"
+pip install .
+```
+
+Now run the command:
+
+```bash
+NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu
+```
 
 ## Flux LoRA
 
 Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know.
diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml
index 3077e3b56..53a8cf8e6 100644
--- a/src/maxdiffusion/configs/base_flux_dev.yml
+++ b/src/maxdiffusion/configs/base_flux_dev.yml
@@ -54,7 +54,7 @@ precision: "DEFAULT"
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash
+attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
 flash_block_sizes: {}
 # Use the following flash_block_sizes on v6e (Trillium) due to larger vmem.
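Note on the install steps above: it can be worth smoke-testing the TransformerEngine fused-attention path before launching a full Flux run. Below is a minimal sketch, assuming a GPU build with `transformer_engine[jax]` installed; the constructor arguments mirror the `setup()` call added to `attention_flax.py` later in this diff, but the shapes and the standalone `init`/`apply` usage are illustrative, so verify against your TE version's API.

```python
# Minimal sketch (not part of this PR): smoke-test TransformerEngine's fused
# attention in isolation. Shapes are illustrative, not Flux's real ones.
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax.transformer import DotProductAttention

batch, seq, heads, head_dim = 1, 512, 24, 128  # head_dim <= 128 (current TE limit)

dpa = DotProductAttention(
    head_dim=head_dim,
    num_attention_heads=heads,
    num_gqa_groups=heads,
    attn_mask_type="no_mask",
    attn_bias_type="NO_BIAS",
    qkv_layout="BSHD_BSHD_BSHD",  # separate q/k/v, each [batch, seq, heads, head_dim]
    scale_factor=1.0 / head_dim**0.5,
    transpose_batch_sequence=False,
    dtype=jnp.bfloat16,
)

q = k = v = jax.random.normal(jax.random.PRNGKey(0), (batch, seq, heads, head_dim), jnp.bfloat16)
params = dpa.init(jax.random.PRNGKey(1), q, k, v, mask=None)
out = dpa.apply(params, q, k, v, mask=None)
print(out.shape)  # expect (1, 512, 24, 128)
```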
diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml
index ee9db566a..1b2bc28b3 100644
--- a/src/maxdiffusion/configs/base_flux_schnell.yml
+++ b/src/maxdiffusion/configs/base_flux_schnell.yml
@@ -53,7 +53,7 @@ precision: "DEFAULT"
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash
+attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
 flash_block_sizes: {
   "block_q" : 256,
   "block_kv_compute" : 256,
diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py
index 1c221ee0b..59564a271 100644
--- a/src/maxdiffusion/generate_flux.py
+++ b/src/maxdiffusion/generate_flux.py
@@ -77,7 +77,7 @@ def unpack(x: Array, height: int, width: int) -> Array:
 
 
 def vae_decode(latents, vae, state, config):
-  img = unpack(x=latents, height=config.resolution, width=config.resolution)
+  img = unpack(x=latents.astype(jnp.float32), height=config.resolution, width=config.resolution)
   img = img / vae.config.scaling_factor + vae.config.shift_factor
   img = vae.apply({"params": state.params}, img, deterministic=True, method=vae.decode).sample
   return img
diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 884b6d688..db8626984 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -52,6 +52,25 @@ class AttentionOp(nn.Module):
   flash_block_sizes: BlockSizes = None
   dtype: DType = jnp.float32
 
+  def setup(self):
+    if self.attention_kernel == "cudnn_flash_te":
+      from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
+
+      self.dpa_layer = DotProductAttention(
+          head_dim=self.dim_head,
+          num_attention_heads=self.heads,
+          num_gqa_groups=self.heads,
+          attn_mask_type="no_mask",  # 'no_mask', 'padding', 'causal', or 'padding_causal'
+          attn_bias_type="NO_BIAS",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+          # attention_dropout=self.dropout_rate,
+          dropout_rng_name="aqt",
+          dtype=self.dtype,
+          # float32_logits=self.float32_logits,
+          qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
+          scale_factor=self.scale,
+          transpose_batch_sequence=False,
+      )
+
   def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None:
     """Check attention inputs."""
 
@@ -64,16 +83,22 @@ def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None
 
   def apply_attention(self, query: Array, key: Array, value: Array):
     """Routes to different attention kernels."""
     self.check_attention_inputs(query, key, value)
-    can_use_flash_attention = (
-        query.shape[1] >= self.flash_min_seq_length
-        and key.shape[1] >= self.flash_min_seq_length
-        and value.shape[1] >= self.flash_min_seq_length
-    )
+
+    if self.attention_kernel == "flash":
+      can_use_flash_attention = (
+          query.shape[1] >= self.flash_min_seq_length
+          and key.shape[1] >= self.flash_min_seq_length
+          and value.shape[1] >= self.flash_min_seq_length
+      )
+    else:
+      can_use_flash_attention = True
 
     if self.attention_kernel == "dot_product" or self.use_memory_efficient_attention or not can_use_flash_attention:
       return self.apply_attention_dot(query, key, value)
     elif self.attention_kernel == "flash":
       return self.tpu_flash_attention(query, key * self.scale, value)
+    elif self.attention_kernel == "cudnn_flash_te":
+      return self.cudnn_flash_attention(query, key, value)
     else:
       raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.")
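The `apply_attention` change above is easy to misread: the `flash_min_seq_length` guard now gates only the TPU `flash` kernel, so short sequences no longer silently demote `cudnn_flash_te` to the dot-product fallback. A standalone restatement of the dispatch rule, with a hypothetical `select_kernel` helper and label strings used purely for illustration:

```python
# Illustration only: the dispatch rule implemented by apply_attention above.
def select_kernel(attention_kernel: str, seq_len: int, flash_min_seq_length: int,
                  use_memory_efficient_attention: bool = False) -> str:
    # Only the TPU flash kernel has a minimum-sequence-length requirement;
    # the other kernels are never gated on sequence length.
    if attention_kernel == "flash":
        can_use_flash = seq_len >= flash_min_seq_length
    else:
        can_use_flash = True
    if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash:
        return "dot_product"      # reference einsum path
    if attention_kernel == "flash":
        return "tpu_flash"        # Pallas flash kernel
    if attention_kernel == "cudnn_flash_te":
        return "cudnn_flash_te"   # TransformerEngine fused attention (new in this change)
    raise ValueError(f"Unexpected attention kernel {attention_kernel=}.")

assert select_kernel("flash", seq_len=128, flash_min_seq_length=4096) == "dot_product"
assert select_kernel("cudnn_flash_te", seq_len=128, flash_min_seq_length=4096) == "cudnn_flash_te"
```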
@@ -132,6 +157,32 @@ def wrap_flash_attention(query, key, value):
 
     return x
 
+  def cudnn_flash_attention(
+      self,
+      query: Array,
+      key: Array,
+      value: Array,
+  ) -> Array:
+    """CUDNN Flash Attention with Transformer Engine.
+    1. Stable API, supports GQA.
+    2. Supports head_dim up to 128; head_dim=256 support will be added soon.
+    """
+    # Transformer Engine is only available in GPU builds; the import happens in setup().
+    # Sharding constraints below are copied from tpu_flash_attention.
+    query = self.reshape_data_for_cudnn_flash(query)
+    key = self.reshape_data_for_cudnn_flash(key)
+    value = self.reshape_data_for_cudnn_flash(value)
+
+    cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV)
+    axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names)
+
+    query = nn.with_logical_constraint(query, axis_names)
+    key = nn.with_logical_constraint(key, axis_names)
+    value = nn.with_logical_constraint(value, axis_names)
+
+    out = self.dpa_layer(query, key, value, mask=None)
+    return self.reshape_data_from_cudnn_flash(out)
+
   def apply_attention_dot(self, query: Array, key: Array, value: Array):
     """Apply Attention."""
     if self.split_head_dim:
@@ -209,6 +260,16 @@ def reshape_batch_dim_to_heads(self, tensor):
     tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
     return tensor
 
+  def reshape_data_for_cudnn_flash(self, tensor):
+    # reshapes from [b, s, h * d] to [b, s, h, d] (input format to cudnn flash format)
+    batch, seq, heads_and_dim_head = tensor.shape
+    tensor = tensor.reshape(batch, seq, self.heads, heads_and_dim_head // self.heads)
+    return tensor
+
+  def reshape_data_from_cudnn_flash(self, tensor):
+    # reshapes from [b, s, h, d] back to [b, s, h * d]
+    return tensor.reshape(tensor.shape[0], tensor.shape[1], -1)
+
   def reshape_data_for_flash(self, tensor):
     # reshapes from [b, s, h * d] to [b, h, s, d] (input format to flash format)
     batch, seq, heads_and_dim_head = tensor.shape
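For reference, the round trip performed by the two new reshape helpers, with illustrative shapes. Unlike `reshape_data_for_flash`, which transposes to `[b, h, s, d]` for the TPU kernel, the cudnn path keeps the sequence axis second (matching the `BSHD_BSHD_BSHD` layout), so a pair of reshapes with no transpose suffices:

```python
# Round trip of reshape_data_for_cudnn_flash / reshape_data_from_cudnn_flash.
# Shapes are illustrative; h * d must match the packed projection width.
import jax.numpy as jnp

b, s, h, d = 1, 512, 24, 128
packed = jnp.zeros((b, s, h * d))                 # attention input: [b, s, h * d]
bshd = packed.reshape(b, s, h, d)                 # -> [b, s, h, d], no transpose needed
assert bshd.shape == (1, 512, 24, 128)
restored = bshd.reshape(bshd.shape[0], bshd.shape[1], -1)  # back to [b, s, h * d]
assert restored.shape == packed.shape
```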
"activation_embed")) for single_block in self.single_blocks: - hidden_states, temb, image_rotary_emb = single_block( - hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb - ) + hidden_states = single_block(hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb) hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] hidden_states = self.norm_out(hidden_states, temb)