diff --git a/README.md b/README.md
index 09492f921..ac6c9c3d7 100644
--- a/README.md
+++ b/README.md
@@ -53,7 +53,7 @@ MaxDiffusion supports
   - [Dreambooth](#dreambooth)
 - [Inference](#inference)
   - [Flux](#flux)
-  - [Flash Attention for GPU:](#flash-attention-for-gpu)
+  - [Fused Attention for GPU:](#fused-attention-for-gpu)
   - [Hyper SDXL LoRA](#hyper-sdxl-lora)
   - [Load Multiple LoRA](#load-multiple-lora)
   - [SDXL Lightning](#sdxl-lightning)
@@ -83,6 +83,14 @@ After installation completes, run the training script.
 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_xl_run" output_dir="gs://your-bucket/" per_device_batch_size=1
 ```
 
+On GPUs with Fused Attention:
+
+First install Transformer Engine by following the [instructions here](#fused-attention-for-gpu).
+
+```bash
+NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_new_unet=true train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=bfloat16 resolution=512 per_device_batch_size=1 attention="cudnn_flash_te" jit_initializers=False
+```
+
 To generate images with a trained checkpoint, run:
 
 ```bash
@@ -176,8 +184,8 @@ To generate images, run the following command:
 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
 ```
 
-## Flash Attention for GPU:
-Flash Attention for GPU is supported via TransformerEngine. Installation instructions:
+## Fused Attention for GPU:
+Fused Attention for GPU is supported via TransformerEngine. Installation instructions:
 
 ```bash
 cd maxdiffusion
diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml
index bf3bc8569..1768bbede 100644
--- a/src/maxdiffusion/configs/base14.yml
+++ b/src/maxdiffusion/configs/base14.yml
@@ -37,6 +37,11 @@ activations_dtype: 'bfloat16'
 # fp32 activations and fp32 weights with HIGHEST will provide the best precision
 # at the cost of time.
 precision: "DEFAULT"
+
+# If False, the state is not jitted and replicate is called instead. This is useful for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: False
 split_head_dim: True
diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml
index 13cbd80c5..4ff025c47 100644
--- a/src/maxdiffusion/configs/base21.yml
+++ b/src/maxdiffusion/configs/base21.yml
@@ -37,6 +37,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"
 
+# If False, the state is not jitted and replicate is called instead. This is useful for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: False
 split_head_dim: True
diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml
index f648a4464..4cf66f5d8 100644
--- a/src/maxdiffusion/configs/base_2_base.yml
+++ b/src/maxdiffusion/configs/base_2_base.yml
@@ -38,6 +38,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"
 
+# If False, the state is not jitted and replicate is called instead. This is useful for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml
index 97060eed3..944153d64 100644
--- a/src/maxdiffusion/configs/base_flux_dev.yml
+++ b/src/maxdiffusion/configs/base_flux_dev.yml
@@ -51,6 +51,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"
 
+# If False, the state is not jitted and replicate is called instead. This is useful for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml
index bafac7ce9..3106255a9 100644
--- a/src/maxdiffusion/configs/base_flux_schnell.yml
+++ b/src/maxdiffusion/configs/base_flux_schnell.yml
@@ -50,6 +50,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"
 
+# If False, the state is not jitted and replicate is called instead. This is useful for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml
index d23cceab0..24cfe3997 100644
--- a/src/maxdiffusion/configs/base_xl.yml
+++ b/src/maxdiffusion/configs/base_xl.yml
@@ -38,6 +38,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"
 
+# If False, the state is not jitted and replicate is called instead. This is useful for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: False
 split_head_dim: True
diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml
index c9eba9df1..60f6fb873 100644
--- a/src/maxdiffusion/configs/base_xl_lightning.yml
+++ b/src/maxdiffusion/configs/base_xl_lightning.yml
@@ -36,6 +36,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"
 
+# If False, the state is not jitted and replicate is called instead. This is useful for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: False
 split_head_dim: True
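The `jit_initializers` flag added to the configs above selects between two initialization paths; the actual implementation is the `max_utils.py` change below. As a rough sketch of the two paths (simplified; `init_fn` and `shardings` stand in for the partial initializer and mesh shardings used in `max_utils.py`, and the state is assumed to be a flax `TrainState`-style object with `.replace`):

```python
import jax

def create_state(init_fn, shardings, model_params=None, jit_initializers=True):
  if jit_initializers:
    # Multi-host path: jit the initializer so the state is created
    # directly with the target shardings.
    state = jax.jit(init_fn, in_shardings=None, out_shardings=shardings)()
  else:
    # Single-host debugging path: run the initializer eagerly,
    # then place the state on devices afterwards.
    state = init_fn()
    state = jax.device_put(state, shardings)
  if model_params is not None:
    # Optionally overwrite the freshly initialized params with loaded weights.
    state = state.replace(params=model_params)
  return state
```

Per the config comment, the eager path is only meant for single-host debugging; multi-host runs must keep `jit_initializers: True`.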
diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py
index 93acb03e7..2d37e9416 100644
--- a/src/maxdiffusion/max_utils.py
+++ b/src/maxdiffusion/max_utils.py
@@ -413,12 +413,19 @@ def setup_initial_state(
         training=training,
         eval_only=False,
     )
-
-    state = jax.jit(
-        init_train_state_partial,
-        in_shardings=None,
-        out_shardings=state_mesh_shardings,
-    )()
+    if config.jit_initializers:
+      state = jax.jit(
+          init_train_state_partial,
+          in_shardings=None,
+          out_shardings=state_mesh_shardings,
+      )()
+    else:
+      state = init_train_state_partial()
+      if model_params:
+        state = state.replace(params=model_params)
+      state = jax.device_put(state, state_mesh_shardings)
+    if model_params:
+      state = state.replace(params=model_params)
 
   state = unbox_logicallypartioned_trainstate(state)
 
diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index b434fe1e8..839903406 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -192,7 +192,17 @@ def cudnn_flash_attention(
     key = nn.with_logical_constraint(key, axis_names)
     value = nn.with_logical_constraint(value, axis_names)
 
-    out = self.dpa_layer(query, key, value, mask=None)
+    @functools.partial(
+        shard_map.shard_map,
+        mesh=self.mesh,
+        in_specs=(axis_names, axis_names, axis_names),
+        out_specs=axis_names,
+        check_rep=False,
+    )
+    def wrap_flash_attention(query, key, value):
+      return jax.vmap(self.dpa_layer)(query, key, value, mask=None)
+
+    out = wrap_flash_attention(query, key, value)
     return self.reshape_data_from_cudnn_flash(out)
 
   def apply_attention_dot(self, query: Array, key: Array, value: Array):
diff --git a/src/maxdiffusion/train_sdxl.py b/src/maxdiffusion/train_sdxl.py
index 6f074da97..cd8021556 100644
--- a/src/maxdiffusion/train_sdxl.py
+++ b/src/maxdiffusion/train_sdxl.py
@@ -46,4 +46,11 @@ def main(argv: Sequence[str]) -> None:
 
 
 if __name__ == "__main__":
+  import os
+  import tensorflow as tf
+  import torch
+
+  tf.config.set_visible_devices([], "GPU")
+  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
+  torch.set_default_device("cpu")
   app.run(main)
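The `attention_flax.py` change above wraps the TransformerEngine `dpa_layer` call in `shard_map` and maps it over the local batch with `jax.vmap`. Below is a self-contained sketch of that pattern using a plain dot-product attention function as a stand-in for the fused kernel; the mesh, axis name, and tensor shapes are illustrative, not the ones maxdiffusion derives from its config:

```python
import functools
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 1-D mesh over all local devices.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("data",))

def attention(q, k, v):
  # Stand-in for the fused kernel: plain scaled dot-product attention
  # on one (seq, heads, head_dim) example.
  scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(q.shape[-1])
  return jnp.einsum("hqk,khd->qhd", jax.nn.softmax(scores, axis=-1), v)

@functools.partial(
    shard_map,
    mesh=mesh,
    in_specs=(P("data"), P("data"), P("data")),  # shard the batch axis
    out_specs=P("data"),
    check_rep=False,
)
def wrapped_attention(q, k, v):
  # Inside shard_map each shard still carries a local batch dimension,
  # so vmap maps the per-example kernel over it, as the diff does for dpa_layer.
  return jax.vmap(attention)(q, k, v)

batch = jax.device_count()
q = k = v = jax.device_put(
    jnp.ones((batch, 128, 8, 64)),  # (batch, seq, heads, head_dim)
    NamedSharding(mesh, P("data")),
)
out = wrapped_attention(q, k, v)
print(out.shape)
```

As in the diff, `check_rep=False` turns off shard_map's replication checking for the wrapped call.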
diff --git a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py
index f6cecb5c3..d22867e45 100644
--- a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py
+++ b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py
@@ -83,18 +83,6 @@ def start_training(self):
     # create train states
     train_states = {}
     state_shardings = {}
-    unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self.create_unet_state(
-        # ambiguous here, but if self.params.get("unet") doesn't exist
-        # Then its 1 of 2 scenarios:
-        # 1. unet state will be loaded directly from orbax
-        # 2. a new unet is being trained from scratch.
-        pipeline=pipeline,
-        params=params,
-        checkpoint_item_name="unet_state",
-        is_training=True,
-    )
-    train_states["unet_state"] = unet_state
-    state_shardings["unet_state_shardings"] = unet_state_mesh_shardings
     vae_state, vae_state_mesh_shardings = self.create_vae_state(
         pipeline=pipeline, params=params, checkpoint_item_name="vae_state", is_training=False
     )
@@ -131,6 +119,19 @@ def start_training(self):
     if self.config.dataset_type == "grain":
       data_iterator = self.restore_data_iterator_state(data_iterator)
 
+    unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self.create_unet_state(
+        # ambiguous here, but if self.params.get("unet") doesn't exist
+        # then it's 1 of 2 scenarios:
+        # 1. unet state will be loaded directly from orbax
+        # 2. a new unet is being trained from scratch.
+        pipeline=pipeline,
+        params=params,
+        checkpoint_item_name="unet_state",
+        is_training=True,
+    )
+    train_states["unet_state"] = unet_state
+    state_shardings["unet_state_shardings"] = unet_state_mesh_shardings
+
     data_shardings = self.get_data_shardings()
     # Compile train_step
     p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)
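Finally, the `__main__` guard added to `train_sdxl.py` hides GPUs from TensorFlow and pins torch to the CPU, presumably so those libraries do not reserve accelerator memory that JAX needs. A quick interactive check of that setup could look like the following (illustrative only; it mirrors the calls added above and prints what each framework sees):

```python
import tensorflow as tf
import torch
import jax

# Mirror the guard added in train_sdxl.py.
tf.config.set_visible_devices([], "GPU")
torch.set_default_device("cpu")

print("TF visible devices:", tf.config.get_visible_devices())  # should list no GPUs
print("Torch default device:", torch.tensor([0.0]).device)     # cpu
print("JAX devices:", jax.devices())                           # accelerators remain visible to JAX
```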