File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -79,10 +79,16 @@ After installation completes, run the training script.
7979
8080- ** Flux**
8181
82+ Expected results on 1024 x 1024 images with flash attention and bfloat16:
83+
84+ | Model | Accelerator | Sharding Strategy | Per Device Batch Size | Global Batch Size | Step Time (secs) |
85+ | --- | --- | --- | --- | --- | --- |
86+ | Flux-dev | v5p-8 | DDP | 1 | 4 | 1.31 |
87+
8288 Flux finetuning has only been tested on TPU v5p.
8389
8490 ``` bash
85- python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml run_name=" test-flux-train" output_dir=" gs://<your-gcs-bucket>/" save_final_checkpoint=True jax_cache_dir=" /tmp/jax_cache" max_train_steps=4500
91+ python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml run_name=" test-flux-train" output_dir=" gs://<your-gcs-bucket>/" save_final_checkpoint=True jax_cache_dir=" /tmp/jax_cache"
8692 ```
8793
8894 To generate images with a finetuned checkpoint, run:
You can’t perform that action at this time.
0 commit comments