Commit 39b38cf

update readme to include training on GPUs. Revert max_utils jitting of state.
1 parent e0a538f commit 39b38cf

2 files changed: 24 additions & 14 deletions

File tree

README.md

Lines changed: 11 additions & 3 deletions
```diff
@@ -53,7 +53,7 @@ MaxDiffusion supports
 - [Dreambooth](#dreambooth)
 - [Inference](#inference)
 - [Flux](#flux)
-- [Flash Attention for GPU:](#flash-attention-for-gpu)
+- [Fused Attention for GPU:](#fused-attention-for-gpu)
 - [Hyper SDXL LoRA](#hyper-sdxl-lora)
 - [Load Multiple LoRA](#load-multiple-lora)
 - [SDXL Lightning](#sdxl-lightning)
```
````diff
@@ -83,6 +83,14 @@ After installation completes, run the training script.
 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_xl_run" output_dir="gs://your-bucket/" per_device_batch_size=1
 ```
 
+On GPUs with Fused Attention:
+
+First install Transformer Engine by following the [instructions here](#fused-attention-for-gpu).
+
+```bash
+NVTE_FUSED_ATTN=1 python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml hardware=gpu run_name='test-sdxl-train' output_dir=/tmp/ train_text_encoder=false cache_latents_text_encoder_outputs=true max_train_steps=200 weights_dtype=float16 activations_dtype=float16 per_device_batch_size=1 attention="cudnn_flash_te"
+```
+
 To generate images with a trained checkpoint, run:
 
 ```bash
````
````diff
@@ -176,8 +184,8 @@ To generate images, run the following command:
 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
 ```
 
-## Flash Attention for GPU:
-Flash Attention for GPU is supported via TransformerEngine. Installation instructions:
+## Fused Attention for GPU:
+Fused Attention for GPU is supported via TransformerEngine. Installation instructions:
 
 ```bash
 cd maxdiffusion
````
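The `NVTE_FUSED_ATTN=1` prefix on the training command above sets an environment variable for that one process only, which is how Transformer Engine's fused-attention backend is toggled per run. A minimal shell sketch of that mechanism (no MaxDiffusion or GPU required):

```shell
# A VAR=value prefix applies only to the environment of that single command.
NVTE_FUSED_ATTN=1 sh -c 'echo "in command: $NVTE_FUSED_ATTN"'   # prints "in command: 1"

# The surrounding shell never sees the variable.
sh -c 'echo "after: ${NVTE_FUSED_ATTN:-unset}"'                 # prints "after: unset"
```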

src/maxdiffusion/max_utils.py

Lines changed: 13 additions & 11 deletions
```diff
@@ -404,18 +404,20 @@ def setup_initial_state(
     state = state[checkpoint_item]
     if not state:
       max_logging.log(f"Could not find the item in orbax, creating state...")
-      state = init_train_state(
-        model=model,
-        tx=tx,
-        weights_init_fn=weights_init_fn,
-        params=model_params,
-        training=training,
-        eval_only=False
+      init_train_state_partial = functools.partial(
+        init_train_state,
+        model=model,
+        tx=tx,
+        weights_init_fn=weights_init_fn,
+        params=model_params,
+        training=training,
+        eval_only=False,
       )
-      if model_params:
-        state = state.replace(params=model_params)
-
-      state = jax.device_put(state, state_mesh_shardings)
+      state = jax.jit(
+        init_train_state_partial,
+        in_shardings=None,
+        out_shardings=state_mesh_shardings,
+      )()
 
     state = unbox_logicallypartioned_trainstate(state)
```

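The `max_utils.py` change binds the initializer's arguments with `functools.partial` and then jits the resulting zero-argument function, so the train state can be created under `jax.jit` (with `out_shardings` describing its target layout) instead of being built first and moved with `jax.device_put` afterwards. A minimal single-device sketch of the pattern, where `init_train_state` is a toy stand-in for the real one in `max_utils.py` and `out_shardings` is omitted so it runs without a device mesh:

```python
import functools

import jax
import jax.numpy as jnp


def init_train_state(rng, dim):
    # Toy stand-in for max_utils.init_train_state: returns a state pytree.
    return {
        "params": jax.random.normal(rng, (dim, dim)),
        "step": jnp.zeros((), jnp.int32),
    }


# Bind all arguments up front, mirroring the functools.partial in the diff...
init_partial = functools.partial(init_train_state, jax.random.PRNGKey(0), 4)

# ...then jit the zero-argument initializer. With out_shardings set to a mesh
# sharding, XLA would allocate each output buffer directly in its final
# layout; here it is left out so the sketch works on a single device.
state = jax.jit(init_partial)()

print(state["params"].shape)  # (4, 4)
print(int(state["step"]))     # 0
```

The practical benefit of jit-initializing with `out_shardings` is that a large state never has to be fully materialized on one device before being resharded.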
0 commit comments
