From 6f536d89cb27d09f8c6ca0b94bcb674f761265f7 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 4 Mar 2025 19:13:34 +0000 Subject: [PATCH 1/6] tf processing on cpu --- src/maxdiffusion/max_utils.py | 23 ++++++++--------- src/maxdiffusion/train_sdxl.py | 6 +++++ .../trainers/base_stable_diffusion_trainer.py | 25 ++++++++++--------- 3 files changed, 29 insertions(+), 25 deletions(-) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 93acb03e7..2dbcbaff1 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -404,21 +404,18 @@ def setup_initial_state( state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") - init_train_state_partial = functools.partial( - init_train_state, - model=model, - tx=tx, - weights_init_fn=weights_init_fn, - params=model_params, - training=training, - eval_only=False, + state = init_train_state( + model=model, + tx=tx, + weights_init_fn=weights_init_fn, + params=model_params, + training=training, + eval_only=False ) + if model_params: + state = state.replace(params=model_params) - state = jax.jit( - init_train_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings, - )() + state = jax.device_put(state, state_mesh_shardings) state = unbox_logicallypartioned_trainstate(state) diff --git a/src/maxdiffusion/train_sdxl.py b/src/maxdiffusion/train_sdxl.py index 6f074da97..5380659a8 100644 --- a/src/maxdiffusion/train_sdxl.py +++ b/src/maxdiffusion/train_sdxl.py @@ -46,4 +46,10 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": + import os + import tensorflow as tf + import torch + tf.config.set_visible_devices([], "GPU") + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + torch.set_default_device("cpu") app.run(main) diff --git a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py index f6cecb5c3..d22867e45 100644 --- 
a/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py +++ b/src/maxdiffusion/trainers/base_stable_diffusion_trainer.py @@ -83,18 +83,6 @@ def start_training(self): # create train states train_states = {} state_shardings = {} - unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self.create_unet_state( - # ambiguous here, but if self.params.get("unet") doesn't exist - # Then its 1 of 2 scenarios: - # 1. unet state will be loaded directly from orbax - # 2. a new unet is being trained from scratch. - pipeline=pipeline, - params=params, - checkpoint_item_name="unet_state", - is_training=True, - ) - train_states["unet_state"] = unet_state - state_shardings["unet_state_shardings"] = unet_state_mesh_shardings vae_state, vae_state_mesh_shardings = self.create_vae_state( pipeline=pipeline, params=params, checkpoint_item_name="vae_state", is_training=False ) @@ -131,6 +119,19 @@ def start_training(self): if self.config.dataset_type == "grain": data_iterator = self.restore_data_iterator_state(data_iterator) + unet_state, unet_state_mesh_shardings, unet_learning_rate_scheduler = self.create_unet_state( + # ambiguous here, but if self.params.get("unet") doesn't exist + # Then its 1 of 2 scenarios: + # 1. unet state will be loaded directly from orbax + # 2. a new unet is being trained from scratch. + pipeline=pipeline, + params=params, + checkpoint_item_name="unet_state", + is_training=True, + ) + train_states["unet_state"] = unet_state + state_shardings["unet_state_shardings"] = unet_state_mesh_shardings + data_shardings = self.get_data_shardings() # Compile train_step p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings) From b11767b12d04e4c7c45ea1d41a6df87c4effb18e Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Tue, 4 Mar 2025 20:55:00 +0000 Subject: [PATCH 2/6] improve attention for gpus by using pmap. 
--- src/maxdiffusion/models/attention_flax.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index db8626984..ec6fac6be 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -180,7 +180,21 @@ def cudnn_flash_attention( key = nn.with_logical_constraint(key, axis_names) value = nn.with_logical_constraint(value, axis_names) - out = self.dpa_layer(query, key, value, mask=None) + @functools.partial( + shard_map.shard_map, + mesh=self.mesh, + in_specs=( + axis_names, + axis_names, + axis_names + ), + out_specs=axis_names, + check_rep=False + ) + def wrap_flash_attention(query, key, value): + return jax.vmap(self.dpa_layer)(query, key, value, mask=None) + + out = wrap_flash_attention(query, key, value)#self.dpa_layer(query, key, value, mask=None) return self.reshape_data_from_cudnn_flash(out) def apply_attention_dot(self, query: Array, key: Array, value: Array): From 39b38cf5ee36e75b710ef5b16e673c734b8cfe1d Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 13 Mar 2025 18:16:38 +0000 Subject: [PATCH 3/6] update readme to include training on GPUs. Revert max_utils jitting of state. --- README.md | 14 +++++++++++--- src/maxdiffusion/max_utils.py | 24 +++++++++++++----------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 09492f921..48730b745 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ MaxDiffusion supports - [Dreambooth](#dreambooth) - [Inference](#inference) - [Flux](#flux) - - [Flash Attention for GPU:](#flash-attention-for-gpu) + - [Fused Attention for GPU:](#fused-attention-for-gpu) - [Hyper SDXL LoRA](#hyper-sdxl-lora) - [Load Multiple LoRA](#load-multiple-lora) - [SDXL Lightning](#sdxl-lightning) @@ -83,6 +83,14 @@ After installation completes, run the training script. 
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_xl_run" output_dir="gs://your-bucket/" per_device_batch_size=1 ``` + On GPUs with Fused Attention: + + First install Transformer Engine by following the [instructions here](#fused-attention-for-gpu). + + ```bash + NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=float16 activations_dtype=float16 per_device_batch_size=1 attention="cudnn_flash_te" + ``` + To generate images with a trained checkpoint, run: ```bash @@ -176,8 +184,8 @@ To generate images, run the following command: python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False ``` - ## Flash Attention for GPU: - Flash Attention for GPU is supported via TransformerEngine. Installation instructions: + ## Fused Attention for GPU: + Fused Attention for GPU is supported via TransformerEngine. 
Installation instructions: ```bash cd maxdiffusion diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 2dbcbaff1..32815f308 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -404,18 +404,20 @@ def setup_initial_state( state = state[checkpoint_item] if not state: max_logging.log(f"Could not find the item in orbax, creating state...") - state = init_train_state( - model=model, - tx=tx, - weights_init_fn=weights_init_fn, - params=model_params, - training=training, - eval_only=False + init_train_state_partial = functools.partial( + init_train_state, + model=model, + tx=tx, + weights_init_fn=weights_init_fn, + params=model_params, + training=training, + eval_only=False, ) - if model_params: - state = state.replace(params=model_params) - - state = jax.device_put(state, state_mesh_shardings) + state = jax.jit( + init_train_state_partial, + in_shardings=None, + out_shardings=state_mesh_shardings, + )() state = unbox_logicallypartioned_trainstate(state) From 618550c1835a9602affe064106c829629a2756bd Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 13 Mar 2025 20:51:15 +0000 Subject: [PATCH 4/6] remove commented out line. 
--- src/maxdiffusion/models/attention_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 76f6cc0c6..97396fe83 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -206,7 +206,7 @@ def cudnn_flash_attention( def wrap_flash_attention(query, key, value): return jax.vmap(self.dpa_layer)(query, key, value, mask=None) - out = wrap_flash_attention(query, key, value)#self.dpa_layer(query, key, value, mask=None) + out = wrap_flash_attention(query, key, value) return self.reshape_data_from_cudnn_flash(out) def apply_attention_dot(self, query: Array, key: Array, value: Array): From 0db81e36410c68ef14e84c3172d77b53fc451cd6 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Thu, 13 Mar 2025 21:49:27 +0000 Subject: [PATCH 5/6] lint --- src/maxdiffusion/models/attention_flax.py | 14 +++++--------- src/maxdiffusion/train_sdxl.py | 1 + 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 97396fe83..839903406 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -193,15 +193,11 @@ def cudnn_flash_attention( value = nn.with_logical_constraint(value, axis_names) @functools.partial( - shard_map.shard_map, - mesh=self.mesh, - in_specs=( - axis_names, - axis_names, - axis_names - ), - out_specs=axis_names, - check_rep=False + shard_map.shard_map, + mesh=self.mesh, + in_specs=(axis_names, axis_names, axis_names), + out_specs=axis_names, + check_rep=False, ) def wrap_flash_attention(query, key, value): return jax.vmap(self.dpa_layer)(query, key, value, mask=None) diff --git a/src/maxdiffusion/train_sdxl.py b/src/maxdiffusion/train_sdxl.py index 5380659a8..cd8021556 100644 --- a/src/maxdiffusion/train_sdxl.py +++ b/src/maxdiffusion/train_sdxl.py @@ -49,6 +49,7 @@ def 
main(argv: Sequence[str]) -> None: import os import tensorflow as tf import torch + tf.config.set_visible_devices([], "GPU") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" torch.set_default_device("cpu") From fcafb4acedb50944b2ca0a6681a9ac0c6077dc41 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Fri, 14 Mar 2025 20:42:23 +0000 Subject: [PATCH 6/6] flag to jit initializers --- README.md | 2 +- src/maxdiffusion/configs/base14.yml | 5 +++++ src/maxdiffusion/configs/base21.yml | 4 ++++ src/maxdiffusion/configs/base_2_base.yml | 4 ++++ src/maxdiffusion/configs/base_flux_dev.yml | 4 ++++ src/maxdiffusion/configs/base_flux_schnell.yml | 4 ++++ src/maxdiffusion/configs/base_xl.yml | 4 ++++ src/maxdiffusion/configs/base_xl_lightning.yml | 4 ++++ src/maxdiffusion/max_utils.py | 18 +++++++++++++----- 9 files changed, 43 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 48730b745..ac6c9c3d7 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ After installation completes, run the training script. First install Transformer Engine by following the [instructions here](#fused-attention-for-gpu). 
```bash - NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=float16 activations_dtype=float16 per_device_batch_size=1 attention="cudnn_flash_te" + NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_new_unet=true train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=bfloat16 resolution=512 per_device_batch_size=1 attention="cudnn_flash_te" jit_initializers=False ``` To generate images with a trained checkpoint, run: diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml index bf3bc8569..1768bbede 100644 --- a/src/maxdiffusion/configs/base14.yml +++ b/src/maxdiffusion/configs/base14.yml @@ -37,6 +37,11 @@ activations_dtype: 'bfloat16' # fp32 activations and fp32 weights with HIGHEST will provide the best precision # at the cost of time. precision: "DEFAULT" + +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + # Set true to load weights from pytorch from_pt: False split_head_dim: True diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml index 13cbd80c5..4ff025c47 100644 --- a/src/maxdiffusion/configs/base21.yml +++ b/src/maxdiffusion/configs/base21.yml @@ -37,6 +37,10 @@ activations_dtype: 'bfloat16' # at the cost of time. precision: "DEFAULT" +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. 
+jit_initializers: True + # Set true to load weights from pytorch from_pt: False split_head_dim: True diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml index f648a4464..4cf66f5d8 100644 --- a/src/maxdiffusion/configs/base_2_base.yml +++ b/src/maxdiffusion/configs/base_2_base.yml @@ -38,6 +38,10 @@ activations_dtype: 'bfloat16' # at the cost of time. precision: "DEFAULT" +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + # Set true to load weights from pytorch from_pt: True split_head_dim: True diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 97060eed3..944153d64 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -51,6 +51,10 @@ activations_dtype: 'bfloat16' # at the cost of time. precision: "DEFAULT" +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + # Set true to load weights from pytorch from_pt: True split_head_dim: True diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index bafac7ce9..3106255a9 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -50,6 +50,10 @@ activations_dtype: 'bfloat16' # at the cost of time. precision: "DEFAULT" +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. 
+jit_initializers: True + # Set true to load weights from pytorch from_pt: True split_head_dim: True diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index d23cceab0..24cfe3997 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -38,6 +38,10 @@ activations_dtype: 'bfloat16' # at the cost of time. precision: "DEFAULT" +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. +jit_initializers: True + # Set true to load weights from pytorch from_pt: False split_head_dim: True diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index c9eba9df1..60f6fb873 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -36,6 +36,10 @@ activations_dtype: 'bfloat16' # at the cost of time. precision: "DEFAULT" +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. 
+jit_initializers: True + # Set true to load weights from pytorch from_pt: False split_head_dim: True diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 32815f308..2d37e9416 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -413,11 +413,19 @@ def setup_initial_state( training=training, eval_only=False, ) - state = jax.jit( - init_train_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings, - )() + if config.jit_initializers: + state = jax.jit( + init_train_state_partial, + in_shardings=None, + out_shardings=state_mesh_shardings, + )() + else: + state = init_train_state_partial() + if model_params: + state = state.replace(params=model_params) + state = jax.device_put(state, state_mesh_shardings) + if model_params: + state = state.replace(params=model_params) state = unbox_logicallypartioned_trainstate(state)