Skip to content

Commit 30b20bd

Browse files
Merge branch 'main' into wan
2 parents 15d242e + da779ea commit 30b20bd

22 files changed

Lines changed: 462 additions & 38 deletions

README.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ MaxDiffusion supports
5353
- [Dreambooth](#dreambooth)
5454
- [Inference](#inference)
5555
- [Flux](#flux)
56-
- [Flash Attention for GPU:](#flash-attention-for-gpu)
56+
- [Fused Attention for GPU:](#fused-attention-for-gpu)
5757
- [Hyper SDXL LoRA](#hyper-sdxl-lora)
5858
- [Load Multiple LoRA](#load-multiple-lora)
5959
- [SDXL Lightning](#sdxl-lightning)
@@ -83,6 +83,14 @@ After installation completes, run the training script.
8383
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_xl_run" output_dir="gs://your-bucket/" per_device_batch_size=1
8484
```
8585

86+
On GPUs with Fused Attention:
87+
88+
First install Transformer Engine by following the [instructions here](#fused-attention-for-gpu).
89+
90+
```bash
91+
NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_new_unet=true train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=bfloat16 resolution=512 per_device_batch_size=1 attention="cudnn_flash_te" jit_initializers=False
92+
```
93+
8694
To generate images with a trained checkpoint, run:
8795

8896
```bash
@@ -176,8 +184,8 @@ To generate images, run the following command:
176184
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
177185
```
178186

179-
## Flash Attention for GPU:
180-
Flash Attention for GPU is supported via TransformerEngine. Installation instructions:
187+
## Fused Attention for GPU:
188+
Fused Attention for GPU is supported via TransformerEngine. Installation instructions:
181189

182190
```bash
183191
cd maxdiffusion

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ tensorflow-datasets>=4.9.6
2424
ruff>=0.1.5,<=0.2
2525
git+https://github.com/mlperf/logging.git
2626
opencv-python-headless==4.10.0.84
27-
orbax-checkpoint==0.10.2
27+
orbax-checkpoint==0.10.3
2828
tokenizers==0.21.0
2929
huggingface_hub==0.24.7
3030
transformers==4.48.1
3131
einops==0.8.0
3232
sentencepiece
33+
aqtp

requirements_with_jax_stable_stack.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ flax>=0.10.2
77
ftfy
88
git+https://github.com/mlperf/logging.git
99
google-cloud-storage==2.17.0
10-
grain-nightly
10+
grain-nightly==0.0.10
1111
huggingface_hub==0.24.7
1212
jax>=0.4.30
1313
jaxlib>=0.4.30
1414
Jinja2
1515
opencv-python-headless==4.10.0.84
1616
optax>=0.2.3
17-
orbax-checkpoint==0.10.2
17+
orbax-checkpoint==0.10.3
1818
parameterized
1919
Pillow
2020
pyink

src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,11 @@ def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training)
8888
config=self.config,
8989
mesh=self.mesh,
9090
weights_init_fn=weights_init_fn,
91-
model_params=None,
91+
model_params=None if self.config.train_new_unet else params.get("unet", None),
9292
checkpoint_manager=self.checkpoint_manager,
9393
checkpoint_item=checkpoint_item_name,
9494
training=is_training,
9595
)
96-
if not self.config.train_new_unet:
97-
unet_state = unet_state.replace(params=params.get("unet", None))
98-
unet_state = jax.device_put(unet_state, state_mesh_shardings)
9996
return unet_state, state_mesh_shardings, learning_rate_scheduler
10097

10198
def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):
@@ -153,20 +150,18 @@ def create_text_encoder_2_state(self, pipeline, params, checkpoint_item_name, is
153150
input_shape=(self.total_train_batch_size, pipeline.tokenizer.model_max_length),
154151
)
155152

156-
state, state_mesh_shardings = max_utils.setup_initial_state(
153+
# state, state_mesh_shardings =
154+
return max_utils.setup_initial_state(
157155
model=pipeline.text_encoder_2,
158156
tx=tx,
159157
config=self.config,
160158
mesh=self.mesh,
161159
weights_init_fn=weights_init_fn,
162-
model_params=None,
160+
model_params=params.get("text_encoder_2", None),
163161
checkpoint_manager=self.checkpoint_manager,
164162
checkpoint_item=checkpoint_item_name,
165163
training=is_training,
166164
)
167-
state = state.replace(params=params.get("text_encoder_2", None))
168-
state = jax.device_put(state, state_mesh_shardings)
169-
return state, state_mesh_shardings
170165

171166
def restore_data_iterator_state(self, data_iterator):
172167
if (

src/maxdiffusion/configs/base14.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ activations_dtype: 'bfloat16'
3737
# fp32 activations and fp32 weights with HIGHEST will provide the best precision
3838
# at the cost of time.
3939
precision: "DEFAULT"
40+
41+
# if False state is not jitted and instead replicate is called. This is good for debugging on single host
42+
# It must be True for multi-host.
43+
jit_initializers: True
44+
4045
# Set true to load weights from pytorch
4146
from_pt: False
4247
split_head_dim: True
@@ -216,3 +221,7 @@ prior_loss_weight: 1.0
216221
num_class_images: 100
217222
# If true, set dataset_save_location.
218223
cache_dreambooth_dataset: False
224+
quantization: ''
225+
# Shard the range finding operation for quantization. By default this is set to number of slices.
226+
quantization_local_shard_count: -1
227+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base21.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ activations_dtype: 'bfloat16'
3737
# at the cost of time.
3838
precision: "DEFAULT"
3939

40+
# if False state is not jitted and instead replicate is called. This is good for debugging on single host
41+
# It must be True for multi-host.
42+
jit_initializers: True
43+
4044
# Set true to load weights from pytorch
4145
from_pt: False
4246
split_head_dim: True
@@ -217,3 +221,7 @@ prior_loss_weight: 1.0
217221
num_class_images: 100
218222
# If true, set dataset_save_location.
219223
cache_dreambooth_dataset: False
224+
quantization: ''
225+
# Shard the range finding operation for quantization. By default this is set to number of slices.
226+
quantization_local_shard_count: -1
227+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ activations_dtype: 'bfloat16'
3838
# at the cost of time.
3939
precision: "DEFAULT"
4040

41+
# if False state is not jitted and instead replicate is called. This is good for debugging on single host
42+
# It must be True for multi-host.
43+
jit_initializers: True
44+
4145
# Set true to load weights from pytorch
4246
from_pt: True
4347
split_head_dim: True
@@ -231,3 +235,8 @@ prior_loss_weight: 1.0
231235
num_class_images: 100
232236
# If true, set dataset_save_location.
233237
cache_dreambooth_dataset: False
238+
239+
quantization: ''
240+
# Shard the range finding operation for quantization. By default this is set to number of slices.
241+
quantization_local_shard_count: -1
242+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ activations_dtype: 'bfloat16'
5151
# at the cost of time.
5252
precision: "DEFAULT"
5353

54+
# if False state is not jitted and instead replicate is called. This is good for debugging on single host
55+
# It must be True for multi-host.
56+
jit_initializers: True
57+
5458
# Set true to load weights from pytorch
5559
from_pt: True
5660
split_head_dim: True
@@ -260,3 +264,8 @@ controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
260264
controlnet_from_pt: True
261265
controlnet_conditioning_scale: 0.5
262266
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
267+
quantization: ''
268+
# Shard the range finding operation for quantization. By default this is set to number of slices.
269+
quantization_local_shard_count: -1
270+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
271+

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ activations_dtype: 'bfloat16'
5050
# at the cost of time.
5151
precision: "DEFAULT"
5252

53+
# if False state is not jitted and instead replicate is called. This is good for debugging on single host
54+
# It must be True for multi-host.
55+
jit_initializers: True
56+
5357
# Set true to load weights from pytorch
5458
from_pt: True
5559
split_head_dim: True
@@ -268,3 +272,7 @@ controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
268272
controlnet_from_pt: True
269273
controlnet_conditioning_scale: 0.5
270274
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
275+
quantization: ''
276+
# Shard the range finding operation for quantization. By default this is set to number of slices.
277+
quantization_local_shard_count: -1
278+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

src/maxdiffusion/configs/base_xl.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ activations_dtype: 'bfloat16'
3838
# at the cost of time.
3939
precision: "DEFAULT"
4040

41+
# if False state is not jitted and instead replicate is called. This is good for debugging on single host
42+
# It must be True for multi-host.
43+
jit_initializers: True
44+
4145
# Set true to load weights from pytorch
4246
from_pt: False
4347
split_head_dim: True
@@ -233,3 +237,9 @@ controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
233237
controlnet_from_pt: True
234238
controlnet_conditioning_scale: 0.5
235239
controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png'
240+
enable_mllog: False
241+
242+
quantization: ''
243+
# Shard the range finding operation for quantization. By default this is set to number of slices.
244+
quantization_local_shard_count: -1
245+
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.

0 commit comments

Comments (0)