Skip to content

Commit 7f0f5bc

Browse files
Flux inference implementation (#146)
Adds support for Flux dev and schnell. --------- Co-authored-by: Juan Acevedo <jfacevedo@google.com>
1 parent 271ce08 commit 7f0f5bc

112 files changed

Lines changed: 2528 additions & 75417 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

README.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
[![Unit Tests](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)
1818

1919
# What's new?
20+
- **`2025/02/08`**: Flux schnell & dev inference.
2021
- **`2024/12/12`**: Load multiple LoRAs for inference.
2122
- **`2024/10/22`**: LoRA support for Hyper SDXL.
2223
- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format.
@@ -46,6 +47,7 @@ MaxDiffusion supports
4647
* [Training](#training)
4748
* [Dreambooth](#dreambooth)
4849
* [Inference](#inference)
50+
* [Flux](#flux)
4951
* [Hyper-SD XL LoRA](#hyper-sdxl-lora)
5052
* [Load Multiple LoRA](#load-multiple-lora)
5153
* [SDXL Lightning](#sdxl-lightning)
@@ -133,6 +135,39 @@ To generate images, run the following command:
133135
```bash
134136
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
135137
```
138+
## Flux
139+
140+
First make sure you have permissions to access the Flux repos on Hugging Face.
141+
142+
Expected results on 1024 x 1024 images with flash attention and bfloat16:
143+
144+
| Model | Accelerator | Sharding Strategy | Batch Size | Steps | time (secs) |
145+
| --- | --- | --- | --- | --- | --- |
146+
| Flux-dev | v4-8 | DDP | 4 | 28 | 23 |
147+
| Flux-schnell | v4-8 | DDP | 4 | 4 | 2.2 |
148+
| Flux-dev | v6e-4 | DDP | 4 | 28 | 5.5 |
149+
| Flux-schnell | v6e-4 | DDP | 4 | 4 | 0.8 |
150+
| Flux-schnell | v6e-4 | FSDP | 4 | 4 | 1.2 |
151+
152+
Schnell:
153+
154+
```bash
155+
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1
156+
```
157+
158+
Dev:
159+
160+
```bash
161+
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1
162+
```
163+
164+
If you are using a TPU v6e (Trillium), you can use optimized flash block sizes for faster inference. Uncomment the optimized block sizes in the Flux-dev [config](src/maxdiffusion/configs/base_flux_dev.yml#60) and the Flux-schnell [config](src/maxdiffusion/configs/base_flux_schnell.yml#68).
165+
166+
To keep the text encoders, the VAE, and the transformer in HBM at all times, the following command shards the model across devices.
167+
168+
```bash
169+
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
170+
```
136171

137172
## Hyper SDXL LoRA
138173

requirements.txt

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ absl-py
66
datasets
77
flax>=0.10.2
88
optax>=0.2.3
9-
torch>=2.3.1
10-
torchvision>=0.18.1
9+
torch==2.5.1
10+
torchvision==0.20.1
1111
ftfy
12-
tensorboard==2.17.0
12+
tensorboard>=2.17.0
1313
tensorboardx==2.6.2.2
1414
tensorboard-plugin-profile==2.15.2
1515
Jinja2
@@ -25,5 +25,8 @@ ruff>=0.1.5,<=0.2
2525
git+https://github.com/mlperf/logging.git
2626
opencv-python-headless==4.10.0.84
2727
orbax-checkpoint==0.10.2
28-
tokenizers==0.20.0
29-
huggingface_hub==0.24.7
28+
tokenizers==0.21.0
29+
huggingface_hub==0.24.7
30+
transformers==4.48.1
31+
einops==0.8.0
32+
sentencepiece

src/maxdiffusion/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@
372372
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
373373
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
374374
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
375+
_import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"]
375376
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
376377
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
377378
_import_structure["schedulers"].extend(
@@ -451,6 +452,7 @@
451452
from .models.controlnet_flax import FlaxControlNetModel
452453
from .models.modeling_flax_utils import FlaxModelMixin
453454
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
455+
from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
454456
from .models.vae_flax import FlaxAutoencoderKL
455457
from .pipelines import FlaxDiffusionPipeline
456458
from .schedulers import (

src/maxdiffusion/checkpointing/base_stable_diffusion_checkpointer.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
max_logging,
3333
)
3434

35-
from maxdiffusion.transformers import (CLIPTokenizer, FlaxCLIPTextModel, CLIPTextConfig, FlaxCLIPTextModelWithProjection)
35+
from transformers import (CLIPTokenizer, FlaxCLIPTextModel, CLIPTextConfig, FlaxCLIPTextModelWithProjection)
3636

3737
from maxdiffusion.checkpointing.checkpointing_utils import (
3838
create_orbax_checkpoint_manager,
@@ -88,11 +88,14 @@ def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training)
8888
config=self.config,
8989
mesh=self.mesh,
9090
weights_init_fn=weights_init_fn,
91-
model_params=None if self.config.train_new_unet else params.get("unet", None),
91+
model_params=None,
9292
checkpoint_manager=self.checkpoint_manager,
9393
checkpoint_item=checkpoint_item_name,
9494
training=is_training,
9595
)
96+
if not self.config.train_new_unet:
97+
unet_state = unet_state.replace(params=params.get("unet", None))
98+
unet_state = jax.device_put(unet_state, state_mesh_shardings)
9699
return unet_state, state_mesh_shardings, learning_rate_scheduler
97100

98101
def create_vae_state(self, pipeline, params, checkpoint_item_name, is_training=False):
@@ -150,17 +153,20 @@ def create_text_encoder_2_state(self, pipeline, params, checkpoint_item_name, is
150153
input_shape=(self.total_train_batch_size, pipeline.tokenizer.model_max_length),
151154
)
152155

153-
return max_utils.setup_initial_state(
156+
state, state_mesh_shardings = max_utils.setup_initial_state(
154157
model=pipeline.text_encoder_2,
155158
tx=tx,
156159
config=self.config,
157160
mesh=self.mesh,
158161
weights_init_fn=weights_init_fn,
159-
model_params=params.get("text_encoder_2", None),
162+
model_params=None,
160163
checkpoint_manager=self.checkpoint_manager,
161164
checkpoint_item=checkpoint_item_name,
162165
training=is_training,
163166
)
167+
state = state.replace(params=params.get("text_encoder_2", None))
168+
state = jax.device_put(state, state_mesh_shardings)
169+
return state, state_mesh_shardings
164170

165171
def restore_data_iterator_state(self, data_iterator):
166172
if (
@@ -302,15 +308,16 @@ def load_checkpoint(self, step=None, scheduler_class=None):
302308
tokenizer_path = os.path.join(tokenizer_path, "tokenizer")
303309
tokenizer_path = max_utils.download_blobs(tokenizer_path, "/tmp")
304310
tokenizer = CLIPTokenizer.from_pretrained(
305-
tokenizer_path, subfolder="tokenizer", dtype=self.config.activations_dtype, weights_dtype=self.config.weights_dtype
311+
tokenizer_path,
312+
subfolder="tokenizer",
313+
dtype=self.config.activations_dtype,
306314
)
307315

308316
te_pretrained_config = CLIPTextConfig(**model_configs[0]["text_encoder_config"])
309317
text_encoder = FlaxCLIPTextModel(
310318
te_pretrained_config,
311319
seed=self.config.seed,
312320
dtype=self.config.activations_dtype,
313-
weights_dtype=self.config.weights_dtype,
314321
_do_init=False,
315322
)
316323

src/maxdiffusion/common_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
BATCH = "activation_batch"
3838
LENGTH = "activation_length"
39+
EMBED = "activation_embed"
3940
HEAD = "activation_heads"
4041
D_KV = "activation_kv"
4142
KEEP_1 = "activation_keep_1"

0 commit comments

Comments
 (0)