diff --git a/README.md b/README.md
index ac6c9c3d7..081d65e8e 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@
 [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)
 
 # What's new?
+- **`2025/04/17`**: Flux Finetuning.
 - **`2025/02/12`**: Flux LoRA for inference.
 - **`2025/02/08`**: Flux schnell & dev inference.
 - **`2024/12/12`**: Load multiple LoRAs for inference.
@@ -76,6 +77,26 @@ For your first time running Maxdiffusion, we provide specific [instructions](doc
 
 After installation completes, run the training script.
 
+- **Flux**
+
+  Expected results on 1024 x 1024 images with flash attention and bfloat16:
+
+  | Model | Accelerator | Sharding Strategy | Per Device Batch Size | Global Batch Size | Step Time (secs) |
+  | --- | --- | --- | --- | --- | --- |
+  | Flux-dev | v5p-8 | DDP | 1 | 4 | 1.31 |
+
+  Flux finetuning has only been tested on TPU v5p.
+
+  ```bash
+  python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml run_name="test-flux-train" output_dir="gs:///" save_final_checkpoint=True jax_cache_dir="/tmp/jax_cache"
+  ```
+
+  To generate images with a finetuned checkpoint, run:
+
+  ```bash
+  python src/maxdiffusion/generate_flux_pipeline.py src/maxdiffusion/configs/base_flux_dev.yml run_name="test-flux-train" output_dir="gs:///" jax_cache_dir="/tmp/jax_cache"
+  ```
+
 - **Stable Diffusion XL**
 
   ```bash
diff --git a/src/maxdiffusion/checkpointing/checkpointing_utils.py b/src/maxdiffusion/checkpointing/checkpointing_utils.py
index aa68267a5..dd78eaa6c 100644
--- a/src/maxdiffusion/checkpointing/checkpointing_utils.py
+++ b/src/maxdiffusion/checkpointing/checkpointing_utils.py
@@ -143,7 +143,8 @@ def load_params_from_path(
 
   ckpt_path = os.path.join(config.checkpoint_dir, str(step), checkpoint_item)
   ckpt_path = epath.Path(ckpt_path)
-  ckpt_path = os.path.abspath(ckpt_path)
+  if not ckpt_path.as_uri().startswith("gs://"):
+    ckpt_path = os.path.abspath(ckpt_path)
 
   restore_args = ocp.checkpoint_utils.construct_restore_args(unboxed_abstract_params)
   restored = ckptr.restore(
diff --git a/src/maxdiffusion/configs/README.md b/src/maxdiffusion/configs/README.md
index a9d13aabe..a052df291 100644
--- a/src/maxdiffusion/configs/README.md
+++ b/src/maxdiffusion/configs/README.md
@@ -12,4 +12,12 @@ base_2_base.yml - used for training and inference using [stable-diffusion-2-base
 
 ## Stable Diffusion XL & SDXL Lightning
 
-base_xl.yml - used to run inference using [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
\ No newline at end of file
+base_xl.yml - used to run inference using [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+
+base_xl_lightning.yml - used to run inference using [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)
+
+## Flux
+
+base_flux_dev.yml - used for training and inference using [Flux Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
+
+base_flux_schnell.yml - used for training and inference using [Flux Schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell)
\ No newline at end of file
diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml
index 167dc8bc8..187bade09 100644
--- a/src/maxdiffusion/configs/base_flux_dev.yml
+++ b/src/maxdiffusion/configs/base_flux_dev.yml
@@ -177,7 +177,7 @@ hf_train_files: ''
 hf_access_token: ''
 image_column: 'image'
 caption_column: 'text'
-resolution: 512
+resolution: 1024
 center_crop: False
 random_flip: False
 # If cache_latents_text_encoder_outputs is True
diff --git a/src/maxdiffusion/generate_flux_pipeline.py b/src/maxdiffusion/generate_flux_pipeline.py
index 8887375d0..6ee469728 100644
--- a/src/maxdiffusion/generate_flux_pipeline.py
+++ b/src/maxdiffusion/generate_flux_pipeline.py
@@ -96,13 +96,13 @@ def run(config):
 
   t0 = time.perf_counter()
   with ExitStack():
-    imgs = pipeline(flux_params=flux_state, timesteps=50, vae_params=vae_state).block_until_ready()
+    imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
   t1 = time.perf_counter()
   max_logging.log(f"Compile time: {t1 - t0:.1f}s.")
 
   t0 = time.perf_counter()
   with ExitStack():
-    imgs = pipeline(flux_params=flux_state, timesteps=50, vae_params=vae_state).block_until_ready()
+    imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
   imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
   t1 = time.perf_counter()
   max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
diff --git a/src/maxdiffusion/pipelines/flux/flux_pipeline.py b/src/maxdiffusion/pipelines/flux/flux_pipeline.py
index e655b491d..112338d57 100644
--- a/src/maxdiffusion/pipelines/flux/flux_pipeline.py
+++ b/src/maxdiffusion/pipelines/flux/flux_pipeline.py
@@ -102,7 +102,7 @@ def vae_decode(self, latents, vae, state, config):
     return img
 
   def vae_encode(self, latents, vae, state):
-    img = vae.apply({"params": state["params"]}, latents, deterministic=True, method=vae.encode).latent_dist.mode()
+    img = vae.apply({"params": state.params}, latents, deterministic=True, method=vae.encode).latent_dist.mode()
     img = vae.config.scaling_factor * (img - vae.config.shift_factor)
     return img
 
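A note on the `checkpointing_utils.py` change above: `os.path.abspath` resolves its argument against the local working directory, so applying it unconditionally corrupts GCS checkpoint paths. Below is a minimal sketch of the failure mode; the bucket path is hypothetical, and a plain string stands in for the `epath.Path` used in the real code.

```python
import os

# os.path.abspath treats "gs://..." as a relative local path: it is joined
# onto the current working directory, and path normalization collapses the
# "//" in the scheme, so the result is no longer a valid GCS URI.
print(os.path.abspath("gs://my-bucket/checkpoints/0"))
# -> /current/working/dir/gs:/my-bucket/checkpoints/0

# The guard from the diff leaves remote URIs untouched and only
# absolutizes genuinely local paths:
ckpt_path = "gs://my-bucket/checkpoints/0"
if not ckpt_path.startswith("gs://"):
    ckpt_path = os.path.abspath(ckpt_path)
```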