Merged
14 changes: 11 additions & 3 deletions README.md
@@ -53,7 +53,7 @@ MaxDiffusion supports
- [Dreambooth](#dreambooth)
- [Inference](#inference)
- [Flux](#flux)
- [Flash Attention for GPU:](#flash-attention-for-gpu)
- [Fused Attention for GPU:](#fused-attention-for-gpu)
Contributor

Is there any reason why we name it fused attention instead of flash attention, considering you are using cudnn_flash_te?

- [Hyper SDXL LoRA](#hyper-sdxl-lora)
- [Load Multiple LoRA](#load-multiple-lora)
- [SDXL Lightning](#sdxl-lightning)
@@ -83,6 +83,14 @@ After installation completes, run the training script.
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_xl_run" output_dir="gs://your-bucket/" per_device_batch_size=1
```

On GPUs with Fused Attention:

First install Transformer Engine by following the [instructions here](#fused-attention-for-gpu).

```bash
NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_new_unet=true train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=bfloat16 resolution=512 per_device_batch_size=1 attention="cudnn_flash_te" jit_initializers=False
```
Contributor

Considering the jit initializer works, maybe remove this config?


To generate images with a trained checkpoint, run:

```bash
@@ -176,8 +184,8 @@ To generate images, run the following command:
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
```

## Flash Attention for GPU:
Flash Attention for GPU is supported via TransformerEngine. Installation instructions:
## Fused Attention for GPU:
Fused Attention for GPU is supported via TransformerEngine. Installation instructions:

```bash
cd maxdiffusion
5 changes: 5 additions & 0 deletions src/maxdiffusion/configs/base14.yml
@@ -37,6 +37,11 @@ activations_dtype: 'bfloat16'
# fp32 activations and fp32 weights with HIGHEST will provide the best precision
# at the cost of time.
precision: "DEFAULT"

# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
# It must be True for multi-host.
jit_initializers: True

# Set true to load weights from pytorch
from_pt: False
split_head_dim: True
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base21.yml
@@ -37,6 +37,10 @@ activations_dtype: 'bfloat16'
# at the cost of time.
precision: "DEFAULT"

# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
# It must be True for multi-host.
jit_initializers: True

# Set true to load weights from pytorch
from_pt: False
split_head_dim: True
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_2_base.yml
@@ -38,6 +38,10 @@ activations_dtype: 'bfloat16'
# at the cost of time.
precision: "DEFAULT"

# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
# It must be True for multi-host.
jit_initializers: True

# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_flux_dev.yml
@@ -51,6 +51,10 @@ activations_dtype: 'bfloat16'
# at the cost of time.
precision: "DEFAULT"

# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
# It must be True for multi-host.
jit_initializers: True

# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_flux_schnell.yml
@@ -50,6 +50,10 @@ activations_dtype: 'bfloat16'
# at the cost of time.
precision: "DEFAULT"

# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
# It must be True for multi-host.
jit_initializers: True

# Set true to load weights from pytorch
from_pt: True
split_head_dim: True
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_xl.yml
@@ -38,6 +38,10 @@ activations_dtype: 'bfloat16'
# at the cost of time.
precision: "DEFAULT"

# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
# It must be True for multi-host.
jit_initializers: True

# Set true to load weights from pytorch
from_pt: False
split_head_dim: True
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/base_xl_lightning.yml
@@ -36,6 +36,10 @@ activations_dtype: 'bfloat16'
# at the cost of time.
precision: "DEFAULT"

# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
# It must be True for multi-host.
jit_initializers: True

# Set true to load weights from pytorch
from_pt: False
split_head_dim: True
19 changes: 13 additions & 6 deletions src/maxdiffusion/max_utils.py
@@ -413,12 +413,19 @@ def setup_initial_state(
training=training,
eval_only=False,
)

state = jax.jit(
init_train_state_partial,
in_shardings=None,
out_shardings=state_mesh_shardings,
)()
if config.jit_initializers:
state = jax.jit(
init_train_state_partial,
in_shardings=None,
out_shardings=state_mesh_shardings,
)()
else:
state = init_train_state_partial()
if model_params:
state = state.replace(params=model_params)
state = jax.device_put(state, state_mesh_shardings)
if model_params:
state = state.replace(params=model_params)

state = unbox_logicallypartioned_trainstate(state)

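
For readers unfamiliar with the two code paths above: when `jit_initializers` is true, the initializer runs under `jax.jit` with `out_shardings`, so the train state is materialized directly with its target shardings (what multi-host runs need); when it is false, the state is built eagerly and then placed on the mesh with `jax.device_put`, which is easier to step through on a single host. Below is a rough, self-contained sketch of that toggle; the `init_params` function, tensor shape, and single `"data"` mesh axis are illustrative assumptions, not MaxDiffusion's real train state.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical stand-in for init_train_state_partial: builds parameters from scratch.
# Toy shape; the leading axis must divide evenly across the devices in the mesh.
def init_params():
  return {"w": jnp.ones((8, 1024))}

mesh = Mesh(np.array(jax.devices()), ("data",))
shardings = {"w": NamedSharding(mesh, P("data", None))}

jit_initializers = True  # analogous to config.jit_initializers

if jit_initializers:
  # Jitted path: outputs are created directly with the requested shardings.
  params = jax.jit(init_params, out_shardings=shardings)()
else:
  # Eager path: initialize on the default device, then place onto the mesh.
  # Simpler to inspect in a debugger on a single host.
  params = init_params()
  params = jax.device_put(params, shardings)

print(jax.tree_util.tree_map(lambda x: x.sharding, params))
```
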
12 changes: 11 additions & 1 deletion src/maxdiffusion/models/attention_flax.py
@@ -192,7 +192,17 @@ def cudnn_flash_attention(
key = nn.with_logical_constraint(key, axis_names)
value = nn.with_logical_constraint(value, axis_names)

out = self.dpa_layer(query, key, value, mask=None)
@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=(axis_names, axis_names, axis_names),
out_specs=axis_names,
check_rep=False,
)
def wrap_flash_attention(query, key, value):
return jax.vmap(self.dpa_layer)(query, key, value, mask=None)

out = wrap_flash_attention(query, key, value)
return self.reshape_data_from_cudnn_flash(out)

def apply_attention_dot(self, query: Array, key: Array, value: Array):
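
For context on the pattern above: `shard_map` hands each device only its local shard of the batch-sharded query/key/value tensors, and `jax.vmap` maps the underlying attention call over that local batch axis, so the fused kernel only ever sees per-device data. The following is a rough standalone sketch of the same wrapping, with plain scaled dot-product attention standing in for TransformerEngine's `dpa_layer`; the mesh, axis name, and tensor shapes are illustrative assumptions, not MaxDiffusion's actual configuration.

```python
import functools
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))
specs = P("data", None, None, None)  # shard only the batch axis across devices

def attention(q, k, v):
  # Stand-in for the fused/flash kernel: one (seq, heads, head_dim) example.
  scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(q.shape[-1])
  return jnp.einsum("hqk,khd->qhd", jax.nn.softmax(scores, axis=-1), v)

@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=(specs, specs, specs),
    out_specs=specs,
    check_rep=False,
)
def wrap_attention(q, k, v):
  # Inside shard_map each device sees its local (batch, seq, heads, head_dim)
  # shard; vmap maps the single-example attention over that local batch axis.
  return jax.vmap(attention)(q, k, v)

batch = 2 * jax.device_count()  # batch must divide evenly across the "data" axis
q = k = v = jnp.ones((batch, 128, 4, 64))
print(wrap_attention(q, k, v).shape)
```
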
7 changes: 7 additions & 0 deletions src/maxdiffusion/train_sdxl.py
@@ -46,4 +46,11 @@ def main(argv: Sequence[str]) -> None:


if __name__ == "__main__":
import os
import tensorflow as tf
import torch

tf.config.set_visible_devices([], "GPU")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
torch.set_default_device("cpu")
app.run(main)
25 changes: 13 additions & 12 deletions src/maxdiffusion/trainers/base_stable_diffusion_trainer.py
@@ -83,18 +83,6 @@ def start_training(self):
# create train states
train_states = {}
state_shardings = {}
unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self.create_unet_state(
# ambiguous here, but if self.params.get("unet") doesn't exist
# Then its 1 of 2 scenarios:
# 1. unet state will be loaded directly from orbax
# 2. a new unet is being trained from scratch.
pipeline=pipeline,
params=params,
checkpoint_item_name="unet_state",
is_training=True,
)
train_states["unet_state"] = unet_state
state_shardings["unet_state_shardings"] = unet_state_mesh_shardings
vae_state, vae_state_mesh_shardings = self.create_vae_state(
pipeline=pipeline, params=params, checkpoint_item_name="vae_state", is_training=False
)
@@ -131,6 +119,19 @@ def start_training(self):
if self.config.dataset_type == "grain":
data_iterator = self.restore_data_iterator_state(data_iterator)

unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self.create_unet_state(
# ambiguous here, but if self.params.get("unet") doesn't exist
# Then its 1 of 2 scenarios:
# 1. unet state will be loaded directly from orbax
# 2. a new unet is being trained from scratch.
pipeline=pipeline,
params=params,
checkpoint_item_name="unet_state",
is_training=True,
)
train_states["unet_state"] = unet_state
state_shardings["unet_state_shardings"] = unet_state_mesh_shardings

data_shardings = self.get_data_shardings()
# Compile train_step
p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)