[](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?
- **`2025/02/08`**: Flux schnell & dev inference.
- **`2024/12/12`**: Load multiple LoRAs for inference.
- **`2024/10/22`**: LoRA support for Hyper SDXL.
- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
* [Training](#training)
* [Dreambooth](#dreambooth)
* [Inference](#inference)
  * [Flux](#flux)
  * [Hyper-SD XL LoRA](#hyper-sdxl-lora)
  * [Load Multiple LoRA](#load-multiple-lora)
  * [SDXL Lightning](#sdxl-lightning)
To generate images, run the following command:

```bash
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
```

## Flux

First make sure you have access to the Flux repositories on Hugging Face (accept the model licenses and authenticate, e.g. with `huggingface-cli login`).

Expected results for 1024 x 1024 images with flash attention and bfloat16:

| Model | Accelerator | Sharding Strategy | Batch Size | Steps | Time (s) |
| --- | --- | --- | --- | --- | --- |
| Flux-dev | v4-8 | DDP | 4 | 28 | 23 |
| Flux-schnell | v4-8 | DDP | 4 | 4 | 2.2 |
| Flux-dev | v6e-4 | DDP | 4 | 28 | 5.5 |
| Flux-schnell | v6e-4 | DDP | 4 | 4 | 0.8 |
| Flux-schnell | v6e-4 | FSDP | 4 | 4 | 1.2 |

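As a rough back-of-the-envelope check, the table above implies a per-image latency of total time divided by batch size (this small script just re-derives those numbers; it is an illustration, not a benchmark harness):

```python
# (batch_size, total_seconds) per configuration, copied from the table above.
runs = {
    "Flux-dev v4-8 DDP": (4, 23.0),
    "Flux-schnell v4-8 DDP": (4, 2.2),
    "Flux-dev v6e-4 DDP": (4, 5.5),
    "Flux-schnell v6e-4 DDP": (4, 0.8),
    "Flux-schnell v6e-4 FSDP": (4, 1.2),
}

for name, (batch, secs) in runs.items():
    # Per-image latency and aggregate throughput implied by one batch.
    print(f"{name}: {secs / batch:.3f} s/image ({batch / secs:.1f} images/s)")
```

For example, Flux-schnell on a v6e-4 with DDP comes out to 0.2 s/image, i.e. about 5 images per second.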
Schnell:

```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1
```

Dev:

```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1
```

If you are using a TPU v6e (Trillium), you can enable optimized flash block sizes for faster inference by uncommenting them in the Flux-dev [config](src/maxdiffusion/configs/base_flux_dev.yml#60) and the Flux-schnell [config](src/maxdiffusion/configs/base_flux_schnell.yml#68).

To keep the text encoders, VAE, and transformer in HBM at all times, the following command shards the model across devices with FSDP instead of offloading the encoders:

```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
```
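In MaxText-style sharding flags like these, a value of `-1` conventionally means "fill this mesh axis with all remaining devices". The hypothetical helper below (`resolve_parallelism` is not part of maxdiffusion; it is a sketch of the convention, assuming at most one `-1` axis) shows how `ici_data_parallelism=1 ici_fsdp_parallelism=-1` resolves on a 4-chip v6e-4:

```python
import math

def resolve_parallelism(num_devices: int, axes: dict) -> dict:
    """Replace a single -1 entry with whatever device count is left over.

    Hypothetical helper illustrating the ici_*_parallelism convention,
    where -1 means "use the rest"; not maxdiffusion's actual code.
    """
    fixed = math.prod(v for v in axes.values() if v != -1)
    assert num_devices % fixed == 0, "fixed axes must divide the device count"
    return {k: (num_devices // fixed if v == -1 else v) for k, v in axes.items()}

# On a v6e-4 (4 chips), data=1 and fsdp=-1 resolves to a pure-FSDP mesh:
print(resolve_parallelism(4, {"data": 1, "fsdp": -1}))  # {'data': 1, 'fsdp': 4}
```

With the mesh fully on the `fsdp` axis, every device holds a shard of the weights, which is what lets the encoders, VAE, and transformer stay resident in HBM.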

## Hyper SDXL LoRA
