Commit 064a3a7

committed
update readme and some dependencies.
1 parent e56825f commit 064a3a7

2 files changed

Lines changed: 37 additions & 2 deletions

README.md

Lines changed: 35 additions & 0 deletions
@@ -17,6 +17,7 @@
[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?
- **`2025/02/08`**: Flux schnell & dev inference.
- **`2024/12/12`**: Load multiple LoRAs for inference.
- **`2024/10/22`**: LoRA support for Hyper SDXL.
- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
@@ -46,6 +47,7 @@ MaxDiffusion supports
* [Training](#training)
* [Dreambooth](#dreambooth)
* [Inference](#inference)
* [Flux](#flux)
* [Hyper-SD XL LoRA](#hyper-sdxl-lora)
* [Load Multiple LoRA](#load-multiple-lora)
* [SDXL Lightning](#sdxl-lightning)
@@ -133,6 +135,39 @@ To generate images, run the following command:
```bash
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
```

## Flux

First, make sure you have permission to access the Flux repos on Hugging Face.

Expected results for 1024 x 1024 images with flash attention and bfloat16:

| Model | Accelerator | Sharding Strategy | Batch Size | Steps | Time (secs) |
| --- | --- | --- | --- | --- | --- |
| Flux-dev | v4-8 | DDP | 4 | 28 | 23 |
| Flux-schnell | v4-8 | DDP | 4 | 4 | 2.2 |
| Flux-dev | v6e-4 | DDP | 4 | 28 | 5.5 |
| Flux-schnell | v6e-4 | DDP | 4 | 4 | 0.8 |
| Flux-schnell | v6e-4 | FSDP | 4 | 4 | 1.2 |
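
Throughput is not listed directly, but it follows from the batch size and wall-clock time. A small illustrative calculation (assuming the Batch Size column is the total number of images produced in the listed time):

```python
# Images/sec implied by the table above. Illustrative arithmetic only; it assumes
# "Batch Size" is the total number of images generated in "Time (secs)".
results = {
    ("Flux-dev", "v4-8", "DDP"): (4, 23.0),
    ("Flux-schnell", "v4-8", "DDP"): (4, 2.2),
    ("Flux-dev", "v6e-4", "DDP"): (4, 5.5),
    ("Flux-schnell", "v6e-4", "DDP"): (4, 0.8),
    ("Flux-schnell", "v6e-4", "FSDP"): (4, 1.2),
}

for (model, accelerator, sharding), (images, secs) in results.items():
    print(f"{model:12s} {accelerator:6s} {sharding:4s} -> {images / secs:.2f} images/sec")
```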

Schnell:

```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1
```

Dev:

```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1
```

If you are using a TPU v6e (Trillium), you can use optimized flash block sizes for faster inference; uncomment them in the Flux-dev [config](src/maxdiffusion/configs/base_flux_dev.yml#60) and the Flux-schnell [config](src/maxdiffusion/configs/base_flux_schnell.yml#68).

To keep the text encoders, VAE, and transformer in HBM at all times, the following command shards the model across devices:

```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
```
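
The sketch below is not MaxDiffusion's actual sharding code; it is a minimal plain-JAX illustration, with a stand-in weight matrix, of what placing all chips on an `fsdp` mesh axis and sharding a parameter along it means: each device holds a slice of the weight instead of a full replica, which is how the whole model can stay resident in HBM.

```python
# Conceptual sketch only (not MaxDiffusion's implementation): put every available
# device on an "fsdp" mesh axis and shard a weight's rows across that axis.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.asarray(jax.devices()).reshape(1, -1)   # axes: (data=1, fsdp=all devices)
mesh = Mesh(devices, axis_names=("data", "fsdp"))

weight = jnp.zeros((4096, 4096))                     # stand-in for one transformer weight
sharded = jax.device_put(weight, NamedSharding(mesh, P("fsdp", None)))

print(sharded.sharding)  # rows split across the fsdp axis, not replicated per device
```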

## Hyper SDXL LoRA

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -6,8 +6,8 @@ absl-py
datasets
flax>=0.10.2
optax>=0.2.3
-torch>=2.3.1
-torchvision>=0.18.1
+torch==2.5.1
+torchvision==0.20.1
ftfy
tensorboard>=2.17.0
tensorboardx==2.6.2.2
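
A quick, optional check (not part of the repo) that the new pins resolved after reinstalling the requirements:

```python
# Confirm the versions pinned in requirements.txt are the ones actually installed.
import torch
import torchvision

print(torch.__version__)        # expect 2.5.1 (possibly with a build suffix such as +cu121)
print(torchvision.__version__)  # expect 0.20.1
```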

0 commit comments
