Merged
Changes from 1 commit
15 changes: 15 additions & 0 deletions README.md
@@ -17,6 +17,7 @@
[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?
- **`2025/04/17`**: Flux Finetuning.
- **`2025/02/12`**: Flux LoRA for inference.
- **`2025/02/08`**: Flux schnell & dev inference.
- **`2024/12/12`**: Load multiple LoRAs for inference.
@@ -76,6 +77,20 @@ For your first time running Maxdiffusion, we provide specific [instructions](doc

After installation completes, run the training script.

- **Flux**

Flux finetuning has only been tested on TPU v5p.

```bash
python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml run_name="test-flux-train" output_dir="gs://<your-gcs-bucket>/" save_final_checkpoint=True jax_cache_dir="/tmp/jax_cache" max_train_steps=4500
```

To generate images with a finetuned checkpoint, run:

```bash
python src/maxdiffusion/generate_flux_pipeline.py src/maxdiffusion/configs/base_flux_dev.yml run_name="test-flux-train" output_dir="gs://<your-gcs-bucket>/" jax_cache_dir="/tmp/jax_cache"
```

- **Stable Diffusion XL**

```bash
3 changes: 2 additions & 1 deletion src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -143,7 +143,8 @@ def load_params_from_path(

 ckpt_path = os.path.join(config.checkpoint_dir, str(step), checkpoint_item)
 ckpt_path = epath.Path(ckpt_path)
-ckpt_path = os.path.abspath(ckpt_path)
+if not ckpt_path.as_uri().startswith("gs://"):
+  ckpt_path = os.path.abspath(ckpt_path)

restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_params)
restored = ckptr.restore(
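This guard matters because `os.path.abspath` treats a `gs://` URI as a relative local path and silently corrupts it. A minimal sketch of the patched behavior (the helper name `normalize_ckpt_path` is hypothetical, for illustration only):

```python
import os

def normalize_ckpt_path(ckpt_path: str) -> str:
    """Hypothetical helper mirroring the patched logic above."""
    if not ckpt_path.startswith("gs://"):
        # os.path.abspath("gs://bucket/ckpt") would prepend the CWD and
        # collapse the double slash, yielding e.g. "/home/user/gs:/bucket/ckpt",
        # so normalization must be skipped for GCS URIs.
        ckpt_path = os.path.abspath(ckpt_path)
    return ckpt_path

print(normalize_ckpt_path("gs://my-bucket/checkpoints/10"))  # unchanged URI
print(os.path.isabs(normalize_ckpt_path("checkpoints/10")))  # True
```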
10 changes: 9 additions & 1 deletion src/maxdiffusion/configs/README.md
@@ -12,4 +12,12 @@ base_2_base.yml - used for training and inference using [stable-diffusion-2-base

## Stable Diffusion XL & SDXL Lightning

-base_xl.yml - used to run inference using [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+base_xl.yml - used to run inference using [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+
+base_xl_lightning.yml - used to run inference using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)
+
+## Flux
+
+base_flux_dev.yml - used for training and inference using [Flux Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
+
+base_flux_schnell.yml - used for training and inference using [Flux Schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/base_flux_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ hf_train_files: ''
hf_access_token: ''
image_column: 'image'
caption_column: 'text'
-resolution: 512
+resolution: 1024
center_crop: False
random_flip: False
# If cache_latents_text_encoder_outputs is True
Expand Down
4 changes: 2 additions & 2 deletions src/maxdiffusion/generate_flux_pipeline.py
@@ -96,13 +96,13 @@ def run(config):

t0 = time.perf_counter()
with ExitStack():
-imgs = pipeline(flux_params=flux_state, timesteps=50, vae_params=vae_state).block_until_ready()
+imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
t1 = time.perf_counter()
max_logging.log(f"Compile time: {t1 - t0:.1f}s.")

t0 = time.perf_counter()
with ExitStack():
-imgs = pipeline(flux_params=flux_state, timesteps=50, vae_params=vae_state).block_until_ready()
+imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
t1 = time.perf_counter()
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
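The two timed calls in `run` follow the usual JAX benchmarking pattern: the first call pays tracing and compilation, the second measures steady-state latency, and `block_until_ready()` forces the asynchronous dispatch to finish before the clock is read. A self-contained sketch of that pattern (the `step` function is a stand-in workload, not MaxDiffusion code):

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Stand-in workload for the diffusion sampling loop.
    return jnp.sin(x) @ jnp.cos(x).T

x = jnp.ones((256, 256))

t0 = time.perf_counter()
step(x).block_until_ready()  # first call: trace + compile + execute
compile_time = time.perf_counter() - t0

t0 = time.perf_counter()
step(x).block_until_ready()  # second call: cached executable only
run_time = time.perf_counter() - t0

print(f"compile: {compile_time:.3f}s, run: {run_time:.3f}s")
```

Without `block_until_ready()`, JAX returns control before the device finishes, so the timer would mostly measure dispatch overhead rather than actual execution.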
2 changes: 1 addition & 1 deletion src/maxdiffusion/pipelines/flux/flux_pipeline.py
@@ -102,7 +102,7 @@ def vae_decode(self, latents, vae, state, config):
return img

def vae_encode(self, latents, vae, state):
-img = vae.apply({"params": state["params"]}, latents, deterministic=True, method=vae.encode).latent_dist.mode()
+img = vae.apply({"params": state.params}, latents, deterministic=True, method=vae.encode).latent_dist.mode()
img = vae.config.scaling_factor * (img - vae.config.shift_factor)
return img
