
Commit 8b77101

Experimental gpu sdxl (#157)

* tf processing on cpu
* improve attention for gpus by using shard_map
* update readme to include training on GPUs; revert max_utils jitting of state
* remove commented out line
* lint
* flag to jit initializers

Co-authored-by: Juan Acevedo <jfacevedo@google.com>

1 parent 3b4f4d5 commit 8b77101

12 files changed: 84 additions & 22 deletions

README.md

Lines changed: 11 additions & 3 deletions
````diff
@@ -53,7 +53,7 @@ MaxDiffusion supports
 - [Dreambooth](#dreambooth)
 - [Inference](#inference)
 - [Flux](#flux)
-- [Flash Attention for GPU:](#flash-attention-for-gpu)
+- [Fused Attention for GPU:](#fused-attention-for-gpu)
 - [Hyper SDXL LoRA](#hyper-sdxl-lora)
 - [Load Multiple LoRA](#load-multiple-lora)
 - [SDXL Lightning](#sdxl-lightning)
@@ -83,6 +83,14 @@ After installation completes, run the training script.
 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_xl_run" output_dir="gs://your-bucket/" per_device_batch_size=1
 ```
 
+On GPUs with Fused Attention:
+
+First install Transformer Engine by following the [instructions here](#fused-attention-for-gpu).
+
+```bash
+NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_new_unet=true train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=bfloat16 resolution=512 per_device_batch_size=1 attention="cudnn_flash_te" jit_initializers=False
+```
+
 To generate images with a trained checkpoint, run:
 
 ```bash
@@ -176,8 +184,8 @@ To generate images, run the following command:
 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
 ```
 
-## Flash Attention for GPU:
-Flash Attention for GPU is supported via TransformerEngine. Installation instructions:
+## Fused Attention for GPU:
+Fused Attention for GPU is supported via TransformerEngine. Installation instructions:
 
 ```bash
 cd maxdiffusion
````
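The `attention="cudnn_flash_te"` run above depends on a working GPU and Transformer Engine install. A minimal, hypothetical preflight sketch (not part of the repo) that only probes for `transformer_engine` importability and GPU visibility before launching the long command:

```python
# preflight.py: hypothetical sanity check before the fused-attention run.
import importlib.util
import os

import jax

def preflight() -> None:
    # The training command sets NVTE_FUSED_ATTN=1 so Transformer Engine
    # prefers its fused attention kernels; warn if it is missing here.
    if os.environ.get("NVTE_FUSED_ATTN") != "1":
        print("note: NVTE_FUSED_ATTN != 1; fused attention may be disabled")
    # Probe for Transformer Engine without importing it fully.
    if importlib.util.find_spec("transformer_engine") is None:
        raise RuntimeError("transformer_engine not installed; see the "
                           "'Fused Attention for GPU' section")
    # hardware=gpu only makes sense if JAX actually sees a GPU.
    if not any(d.platform == "gpu" for d in jax.devices()):
        raise RuntimeError("no GPU visible to JAX")

if __name__ == "__main__":
    preflight()
```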

src/maxdiffusion/configs/base14.yml

Lines changed: 5 additions & 0 deletions
```diff
@@ -37,6 +37,11 @@ activations_dtype: 'bfloat16'
 # fp32 activations and fp32 weights with HIGHEST will provide the best precision
 # at the cost of time.
 precision: "DEFAULT"
+
+# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: False
 split_head_dim: True
```
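The same flag, default, and comment are added to every bundled config below. For illustration only, a sketch of reading the new key straight from the YAML with PyYAML; the repo's actual config loader and command-line overrides (like `jit_initializers=False` in the training command above) are not shown:

```python
# Illustration only: read the new flag directly with PyYAML, bypassing
# MaxDiffusion's real config loading and CLI-override machinery.
import yaml

with open("src/maxdiffusion/configs/base14.yml") as f:
    cfg = yaml.safe_load(f)

# Defaults to True per the diff; set False for single-host debugging.
print(cfg["jit_initializers"])
```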

src/maxdiffusion/configs/base21.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -37,6 +37,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"
 
+# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: False
 split_head_dim: True
```

src/maxdiffusion/configs/base_2_base.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -38,6 +38,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"
 
+# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
```

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -51,6 +51,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"
 
+# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
```

src/maxdiffusion/configs/base_flux_schnell.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -50,6 +50,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"
 
+# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
```

src/maxdiffusion/configs/base_xl.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -38,6 +38,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"
 
+# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: False
 split_head_dim: True
```

src/maxdiffusion/configs/base_xl_lightning.yml

Lines changed: 4 additions & 0 deletions
```diff
@@ -36,6 +36,10 @@ activations_dtype: 'bfloat16'
 # at the cost of time.
 precision: "DEFAULT"
 
+# If False, state is not jitted and replicate is called instead. This is good for debugging on a single host.
+# It must be True for multi-host.
+jit_initializers: True
+
 # Set true to load weights from pytorch
 from_pt: False
 split_head_dim: True
```

src/maxdiffusion/max_utils.py

Lines changed: 13 additions & 6 deletions
```diff
@@ -413,12 +413,19 @@ def setup_initial_state(
       training=training,
       eval_only=False,
   )
-
-  state = jax.jit(
-      init_train_state_partial,
-      in_shardings=None,
-      out_shardings=state_mesh_shardings,
-  )()
+  if config.jit_initializers:
+    state = jax.jit(
+        init_train_state_partial,
+        in_shardings=None,
+        out_shardings=state_mesh_shardings,
+    )()
+  else:
+    state = init_train_state_partial()
+    if model_params:
+      state = state.replace(params=model_params)
+    state = jax.device_put(state, state_mesh_shardings)
+  if model_params:
+    state = state.replace(params=model_params)
 
   state = unbox_logicallypartioned_trainstate(state)
```
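In the jitted path, `out_shardings` makes the compiled initializer materialize parameters directly in their final sharded layout, which multi-host initialization needs; the eager path builds the state on the default device first and then moves it. A minimal, self-contained sketch of the two branches, with a toy `init_fn` and a fully replicated sharding standing in for `init_train_state_partial` and `state_mesh_shardings`:

```python
# Toy version of the two init paths selected by config.jit_initializers.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.asarray(jax.devices()), axis_names=("data",))
shardings = NamedSharding(mesh, P())  # toy stand-in for state_mesh_shardings

def init_fn():  # toy stand-in for init_train_state_partial
    return {"w": jnp.zeros((8, 8)), "step": jnp.zeros(())}

jit_initializers = True  # the new config flag

if jit_initializers:
    # Multi-host path: compile the initializer with the target shardings,
    # so parameters materialize on-device in their final layout.
    state = jax.jit(init_fn, in_shardings=None, out_shardings=shardings)()
else:
    # Single-host debugging path: initialize eagerly, then place the
    # result onto the mesh afterwards.
    state = init_fn()
    state = jax.device_put(state, shardings)

print(jax.tree_util.tree_map(lambda x: x.sharding, state))
```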
src/maxdiffusion/models/attention_flax.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -192,7 +192,17 @@ def cudnn_flash_attention(
     key = nn.with_logical_constraint(key, axis_names)
     value = nn.with_logical_constraint(value, axis_names)
 
-    out = self.dpa_layer(query, key, value, mask=None)
+    @functools.partial(
+        shard_map.shard_map,
+        mesh=self.mesh,
+        in_specs=(axis_names, axis_names, axis_names),
+        out_specs=axis_names,
+        check_rep=False,
+    )
+    def wrap_flash_attention(query, key, value):
+      return jax.vmap(self.dpa_layer)(query, key, value, mask=None)
+
+    out = wrap_flash_attention(query, key, value)
     return self.reshape_data_from_cudnn_flash(out)
 
   def apply_attention_dot(self, query: Array, key: Array, value: Array):
```
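The pattern here: `shard_map` hands each device its local shard of the batch, and `jax.vmap` maps a per-example kernel over that local batch. A self-contained sketch with plain dot-product attention standing in for the TransformerEngine layer (`self.dpa_layer`), assuming a 1-D mesh over the batch axis:

```python
# Self-contained sketch of the shard_map + vmap pattern above, with a
# plain dot-product kernel standing in for the TransformerEngine layer.
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.asarray(jax.devices()), axis_names=("data",))
spec = P("data", None, None, None)  # shard only the batch dimension

def attention_one(q, k, v):
    # Single example: q, k, v have shape (seq, heads, head_dim).
    scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("hqk,khd->qhd", jax.nn.softmax(scores, axis=-1), v)

@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=(spec, spec, spec),
    out_specs=spec,
    check_rep=False,  # per-shard kernel; skip replication checking
)
def sharded_attention(q, k, v):
    # Inside shard_map, q/k/v are the device-local batch shards, so
    # vmap runs the per-example kernel over the local batch.
    return jax.vmap(attention_one)(q, k, v)

batch = len(jax.devices())  # keep the batch divisible by the mesh size
q = k = v = jnp.ones((batch, 16, 4, 32))
print(sharded_attention(q, k, v).shape)  # (batch, 16, 4, 32)
```

`check_rep=False` mirrors the diff: the wrapped kernel is opaque to shard_map's replication checker, so the check is skipped rather than proven.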
