Skip to content

Commit c8b71f1

Browse files
Merge branch 'main' into multi_res_support
2 parents e2e2c50 + e6f35ce commit c8b71f1

16 files changed

Lines changed: 1462 additions & 30 deletions

README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
[![Unit Tests](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)
1818

1919
# What's new?
20+
- **`2025/04/17`**: Flux Finetuning.
2021
- **`2025/02/12`**: Flux LoRA for inference.
2122
- **`2025/02/08`**: Flux schnell & dev inference.
2223
- **`2024/12/12`**: Load multiple LoRAs for inference.
@@ -76,6 +77,26 @@ For your first time running Maxdiffusion, we provide specific [instructions](doc
7677

7778
After installation completes, run the training script.
7879

80+
- **Flux**
81+
82+
Expected results on 1024 x 1024 images with flash attention and bfloat16:
83+
84+
| Model | Accelerator | Sharding Strategy | Per Device Batch Size | Global Batch Size | Step Time (secs) |
85+
| --- | --- | --- | --- | --- | --- |
86+
| Flux-dev | v5p-8 | DDP | 1 | 4 | 1.31 |
87+
88+
Flux finetuning has only been tested on TPU v5p.
89+
90+
```bash
91+
python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml run_name="test-flux-train" output_dir="gs://<your-gcs-bucket>/" save_final_checkpoint=True jax_cache_dir="/tmp/jax_cache"
92+
```
93+
94+
To generate images with a finetuned checkpoint, run:
95+
96+
```bash
97+
python src/maxdiffusion/generate_flux_pipeline.py src/maxdiffusion/configs/base_flux_dev.yml run_name="test-flux-train" output_dir="gs://<your-gcs-bucket>/" jax_cache_dir="/tmp/jax_cache"
98+
```
99+
79100
- **Stable Diffusion XL**
80101

81102
```bash

requirements_with_jax_stable_stack.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ flax>=0.10.2
77
ftfy
88
git+https://github.com/mlperf/logging.git
99
google-cloud-storage==2.17.0
10-
grain-nightly
10+
grain-nightly==0.0.10
1111
huggingface_hub==0.24.7
1212
jax>=0.4.30
1313
jaxlib>=0.4.30

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
STABLE_DIFFUSION_CHECKPOINT = "STABLE_DIFFUSION_CHECKPOINT"
3535
STABLE_DIFFUSION_XL_CHECKPOINT = "STABLE_DIFUSSION_XL_CHECKPOINT"
36+
FLUX_CHECKPOINT = "FLUX_CHECKPOINT"
3637

3738

3839
def create_orbax_checkpoint_manager(
@@ -56,17 +57,20 @@ def create_orbax_checkpoint_manager(
5657
max_logging.log(f"checkpoint dir: {checkpoint_dir}")
5758
p = epath.Path(checkpoint_dir)
5859

59-
item_names = (
60-
"unet_config",
61-
"vae_config",
62-
"text_encoder_config",
63-
"scheduler_config",
64-
"unet_state",
65-
"vae_state",
66-
"text_encoder_state",
67-
"tokenizer_config",
68-
)
69-
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
60+
if checkpoint_type == FLUX_CHECKPOINT:
61+
item_names = ("flux_state", "flux_config", "vae_state", "vae_config", "scheduler", "scheduler_config")
62+
else:
63+
item_names = (
64+
"unet_config",
65+
"vae_config",
66+
"text_encoder_config",
67+
"scheduler_config",
68+
"unet_state",
69+
"vae_state",
70+
"text_encoder_state",
71+
"tokenizer_config",
72+
)
73+
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
7074
item_names += (
7175
"text_encoder_2_state",
7276
"text_encoder_2_config",
@@ -117,7 +121,7 @@ def load_stable_diffusion_configs(
117121
"tokenizer_config": orbax.checkpoint.args.JsonRestore(),
118122
}
119123

120-
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT:
124+
if checkpoint_type == STABLE_DIFFUSION_XL_CHECKPOINT or checkpoint_type == FLUX_CHECKPOINT:
121125
restore_args["text_encoder_2_config"] = orbax.checkpoint.args.JsonRestore()
122126

123127
return (checkpoint_manager.restore(step, args=orbax.checkpoint.args.Composite(**restore_args)), None)
@@ -139,6 +143,8 @@ def load_params_from_path(
139143

140144
ckpt_path = os.path.join(config.checkpoint_dir, str(step), checkpoint_item)
141145
ckpt_path = epath.Path(ckpt_path)
146+
if not ckpt_path.as_uri().startswith("gs://"):
147+
ckpt_path = os.path.abspath(ckpt_path)
142148

143149
restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_params)
144150
restored = ckptr.restore(

0 commit comments

Comments
 (0)