Commit 286e452: initial commit
Parent: 9e74521
1 file changed: README.md (91 additions, 68 deletions)
````diff
@@ -14,9 +14,10 @@
 limitations under the License.
 -->
 
-[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)
+[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)
 
 # What's new?
+- **`2025/10/10`**: Wan2.1 txt2vid training and generation is now supported.
 - **`2025/8/14`**: LTX-Video img2vid generation is now supported.
 - **`2025/7/29`**: LTX-Video text2vid generation is now supported.
 - **`2025/04/17`**: Flux Finetuning.
````
````diff
@@ -44,6 +45,7 @@ MaxDiffusion supports
 * ControlNet inference (Stable Diffusion 1.4 & SDXL).
 * Dreambooth training support for Stable Diffusion 1.x,2.x.
 * LTX-Video text2vid, img2vid (inference).
+* Wan2.1 text2vid (training and inference).
 
 
 # Table of Contents
````
````diff
@@ -54,15 +56,23 @@ MaxDiffusion supports
 - [Getting Started](#getting-started)
   - [Getting Started:](#getting-started-1)
   - [Training](#training)
-    - [Dreambooth](#dreambooth)
+    - [Wan2.1](#wan-21-training)
+    - [Flux](#flux-training)
+    - [SDXL](#stable-diffusion-xl-training)
+    - [SD 2 base](#stable-diffusion-2-base-training)
+    - [SD 1.4](#stable-diffusion-14-training)
+    - [Dreambooth](#dreambooth)
   - [Inference](#inference)
-    - [LTX-Video](#ltx-video)
-    - [Flux](#flux)
-    - [Fused Attention for GPU:](#fused-attention-for-gpu)
-    - [Hyper SDXL LoRA](#hyper-sdxl-lora)
-    - [Load Multiple LoRA](#load-multiple-lora)
-    - [SDXL Lightning](#sdxl-lightning)
-    - [ControlNet](#controlnet)
+    - [LTX-Video](#ltx-video)
+    - [Flux](#flux)
+    - [Fused Attention for GPU](#fused-attention-for-gpu)
+    - [SDXL](#stable-diffusion-xl)
+    - [SD 2 base](#stable-diffusion-2-base)
+    - [SD 2.1](#stable-diffusion-21)
+    - [Hyper SDXL LoRA](#hyper-sdxl-lora)
+    - [Load Multiple LoRA](#load-multiple-lora)
+    - [SDXL Lightning](#sdxl-lightning)
+    - [ControlNet](#controlnet)
 - [Getting Started: Multihost development](#getting-started-multihost-development)
 - [Comparison to Alternatives](#comparison-to-alternatives)
 - [Development](#development)
````
````diff
@@ -81,7 +91,11 @@ For your first time running Maxdiffusion, we provide specific [instructions](doc
 
 After installation completes, run the training script.
 
-- **Flux**
+## Wan 2.1 Training
+
+Foo
+
+## Flux Training
 
 Expected results on 1024 x 1024 images with flash attention and bfloat16:
 
````
````diff
@@ -101,7 +115,7 @@
 python src/maxdiffusion/generate_flux_pipeline.py src/maxdiffusion/configs/base_flux_dev.yml run_name="test-flux-train" output_dir="gs://<your-gcs-bucket>/" jax_cache_dir="/tmp/jax_cache"
 ```
 
-- **Stable Diffusion XL**
+## Stable Diffusion XL Training
 
 ```bash
 export LIBTPU_INIT_ARGS=""
````
````diff
@@ -122,14 +136,14 @@
 python -m src.maxdiffusion.generate src/maxdiffusion/configs/base_xl.yml run_name="my_run" pretrained_model_name_or_path=<your_saved_checkpoint_path> from_pt=False attention=dot_product
 ```
 
-- **Stable Diffusion 2 base**
+## Stable Diffusion 2 base Training
 
 ```bash
 export LIBTPU_INIT_ARGS=""
 python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name="my_run" jax_cache_dir=gs://your-bucket/cache_dir activations_dtype=float32 weights_dtype=float32 per_device_batch_size=2 precision=DEFAULT dataset_save_location=/tmp/my_dataset/ output_dir=gs://your-bucket/ attention=flash
 ```
 
-- **Stable Diffusion 1.4**
+## Stable Diffusion 1.4 Training
 
 ```bash
 export LIBTPU_INIT_ARGS=""
````
````diff
@@ -144,7 +158,7 @@
 
 ## Dreambooth
 
-**Stable Diffusion 1.x,2.x**
+Supported models are **Stable Diffusion 1.x, 2.x**.
 
 ```bash
 python src/maxdiffusion/dreambooth/train_dreambooth.py src/maxdiffusion/configs/base14.yml class_data_dir=<your-class-dir> instance_data_dir=<your-instance-dir> instance_prompt="a photo of ohwx dog" class_prompt="a photo of a dog" max_train_steps=150 jax_cache_dir=<your-cache-dir> activations_dtype=bfloat16 weights_dtype=float32 per_device_batch_size=1 enable_profiler=False precision=DEFAULT cache_dreambooth_dataset=False learning_rate=4e-6 num_class_images=100 run_name=<your-run-name> output_dir=gs://<your-bucket-name>
````
````diff
@@ -153,7 +167,7 @@
 ## Inference
 
 To generate images, run the following command:
-- **Stable Diffusion XL**
+## Stable Diffusion XL
 
 Single and Multi host inference is supported with sharding annotations:
 
````
````diff
@@ -167,25 +181,35 @@
 python -m src.maxdiffusion.generate_sdxl_replicated
 ```
 
-- **Stable Diffusion 2 base**
+## Stable Diffusion 2 base
 ```bash
 python -m src.maxdiffusion.generate src/maxdiffusion/configs/base_2_base.yml run_name="my_run"
+```
 
-- **Stable Diffusion 2.1**
+## Stable Diffusion 2.1
 ```bash
 python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
 ```
+
 ## LTX-Video
-- In the folder src/maxdiffusion/models/ltx_video/utils, run:
-```bash
-python convert_torch_weights_to_jax.py --ckpt_path [LOCAL DIRECTORY FOR WEIGHTS] --transformer_config_path ../ltxv-13B.json
-```
-- In the repo folder, run:
-```bash
-python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml output_dir="[SAME DIRECTORY]" config_path="src/maxdiffusion/models/ltx_video/ltxv-13B.json"
-```
-- Img2video Generation:
-Add conditioning image path as conditioning_media_paths in the form of ["IMAGE_PATH"] along with other generation parameters in the ltx_video.yml file. Then follow same instruction as above.
+In the folder src/maxdiffusion/models/ltx_video/utils, run:
+
+```bash
+python convert_torch_weights_to_jax.py --ckpt_path [LOCAL DIRECTORY FOR WEIGHTS] --transformer_config_path ../ltxv-13B.json
+```
+
+In the repo folder, run:
+```bash
+python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml output_dir="[SAME DIRECTORY]" config_path="src/maxdiffusion/models/ltx_video/ltxv-13B.json"
+```
+Img2video Generation:
+
+Add a conditioning image path as conditioning_media_paths in the form of ["IMAGE_PATH"], along with other generation parameters, in the ltx_video.yml file. Then follow the same instructions as above.
+
+## Wan2.1
+
+
+
 ## Flux
 
 First make sure you have permissions to access the Flux repos in Huggingface.
````
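The img2vid step in the hunk above is a config edit rather than a command. A minimal sketch of what that edit to `ltx_video.yml` could look like, assuming only the `conditioning_media_paths` key named in the text (the other keys are hypothetical stand-ins for the "other generation parameters"):

```yaml
# Sketch of an edit to src/maxdiffusion/configs/ltx_video.yml.
# conditioning_media_paths is the key named in the README text;
# the prompt and step-count keys below are hypothetical placeholders.
conditioning_media_paths: ["/path/to/conditioning_image.png"]
prompt: "a corgi running on a beach"   # hypothetical
num_inference_steps: 30                # hypothetical
```

With the image path in place, the same `generate_ltx_video.py` command shown above is rerun unchanged.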
````diff
@@ -219,41 +243,41 @@
 ```bash
 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
 ```
-## Fused Attention for GPU:
-Fused Attention for GPU is supported via TransformerEngine. Installation instructions:
+## Fused Attention for GPU
+Fused Attention for GPU is supported via TransformerEngine. Installation instructions:
 
-```bash
-cd maxdiffusion
-pip install -U "jax[cuda12]"
-pip install -r requirements.txt
-pip install --upgrade torch torchvision
-pip install "transformer_engine[jax]
-pip install .
-```
+```bash
+cd maxdiffusion
+pip install -U "jax[cuda12]"
+pip install -r requirements.txt
+pip install --upgrade torch torchvision
+pip install "transformer_engine[jax]"
+pip install .
+```
 
-Now run the command:
+Now run the command:
 
-```bash
-NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu
-```
+```bash
+NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu
+```
 
-## Flux LoRA
+## Flux LoRA
 
-Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know.
+Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know.
 
-Tested with [Amateur Photography](https://civitai.com/models/652699/amateur-photography-flux-dev) and [XLabs-AI](https://huggingface.co/XLabs-AI/flux-lora-collection/tree/main) LoRA collection.
+Tested with [Amateur Photography](https://civitai.com/models/652699/amateur-photography-flux-dev) and [XLabs-AI](https://huggingface.co/XLabs-AI/flux-lora-collection/tree/main) LoRA collection.
 
-First download the LoRA file to a local directory, for example, `/home/jfacevedo/anime_lora.safetensors`. Then run as follows:
+First download the LoRA file to a local directory, for example, `/home/jfacevedo/anime_lora.safetensors`. Then run as follows:
 
-```bash
-python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}'
-```
+```bash
+python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}'
+```
 
-Loading multiple LoRAs is supported as follows:
+Loading multiple LoRAs is supported as follows:
 
-```bash
-python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors", "/home/jfacevedo/amateurphoto-v6-forcu.safetensors"], "weight_name" : ["anime_lora.safetensors","amateurphoto-v6-forcu.safetensors"], "adapter_name" : ["anime","realistic"], "scale": [0.6, 0.6], "from_pt": ["true","true"]}'
-```
+```bash
+python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 ici_data_parallelism=1 ici_fsdp_parallelism=-1 split_head_dim=True lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors", "/home/jfacevedo/amateurphoto-v6-forcu.safetensors"], "weight_name" : ["anime_lora.safetensors","amateurphoto-v6-forcu.safetensors"], "adapter_name" : ["anime","realistic"], "scale": [0.6, 0.6], "from_pt": ["true","true"]}'
+```
 
 ## Hyper SDXL LoRA
 
````
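The `lora_config` flag used in the Flux LoRA commands above is a JSON string whose parallel lists must stay index-aligned: entry *i* of `weight_name`, `adapter_name`, `scale`, and `from_pt` all describe the LoRA at entry *i* of `lora_model_name_or_path`. A hedged sketch, assuming only that the flag is parsed as JSON, of building the string with `json.dumps` so quoting and alignment stay correct (the paths and adapter names are illustrative):

```python
import json

# One entry per LoRA in every list; index i across all lists refers to
# the same adapter. Paths and names here are illustrative only.
lora_config = {
    "lora_model_name_or_path": ["/home/jfacevedo/anime_lora.safetensors"],
    "weight_name": ["anime_lora.safetensors"],
    "adapter_name": ["anime"],
    "scale": [0.8],
    "from_pt": ["true"],
}

# Sanity-check alignment before emitting the flag.
lengths = {len(v) for v in lora_config.values()}
assert len(lengths) == 1, "all lora_config lists must be the same length"

# The exact value to paste on the command line as lora_config='...'
flag = "lora_config='" + json.dumps(lora_config) + "'"
print(flag)
```

Generating the string this way avoids mismatched quotes when the config is edited by hand.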
````diff
@@ -265,36 +289,35 @@
 
 ## Load Multiple LoRA
 
-Supports loading multiple LoRAs for inference. Both from local or from HuggingFace hub.
+Supports loading multiple LoRAs for inference, both from local files and from the HuggingFace hub.
 
-```bash
-python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name="test-lora" output_dir=/tmp/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=30 do_classifier_free_guidance=True prompt="ultra detailed diagram blueprint of a papercut Sitting MaineCoon cat, wide canvas, ampereart, electrical diagram, bl3uprint, papercut" per_device_batch_size=1 diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/blueprintify-sd-xl-10.safetensors","TheLastBen/Papercut_SDXL"], "weight_name" : ["/home/jfacevedo/blueprintify-sd-xl-10.safetensors","papercut.safetensors"], "adapter_name" : ["blueprint","papercut"], "scale": [0.8, 0.7], "from_pt": ["true", "true"]}'
-```
+```bash
+python src/maxdiffusion/generate_sdxl.py src/maxdiffusion/configs/base_xl.yml run_name="test-lora" output_dir=/tmp/tmp/ jax_cache_dir=/tmp/cache_dir/ num_inference_steps=30 do_classifier_free_guidance=True prompt="ultra detailed diagram blueprint of a papercut Sitting MaineCoon cat, wide canvas, ampereart, electrical diagram, bl3uprint, papercut" per_device_batch_size=1 diffusion_scheduler_config='{"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}' lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/blueprintify-sd-xl-10.safetensors","TheLastBen/Papercut_SDXL"], "weight_name" : ["/home/jfacevedo/blueprintify-sd-xl-10.safetensors","papercut.safetensors"], "adapter_name" : ["blueprint","papercut"], "scale": [0.8, 0.7], "from_pt": ["true", "true"]}'
+```
 
 ## SDXL Lightning
 
 Single and Multi host inference is supported with sharding annotations:
 
-```bash
-python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl_lightning.yml run_name="my_run" lightning_repo="ByteDance/SDXL-Lightning" lightning_ckpt="sdxl_lightning_4step_unet.safetensors"
-```
+```bash
+python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl_lightning.yml run_name="my_run" lightning_repo="ByteDance/SDXL-Lightning" lightning_ckpt="sdxl_lightning_4step_unet.safetensors"
+```
 
 ## ControlNet
 
 Might require installing extra libraries for opencv: `apt-get update && apt-get install ffmpeg libsm6 libxext6 -y`
 
-- Stable Diffusion 1.4
+### Stable Diffusion 1.4
 
-```bash
-python src/maxdiffusion/controlnet/generate_controlnet_replicated.py
-```
-
-- Stable Diffusion XL
+```bash
+python src/maxdiffusion/controlnet/generate_controlnet_replicated.py
+```
 
-```bash
-python src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py
-```
+### Stable Diffusion XL
 
+```bash
+python src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py
+```
 
 ## Getting Started: Multihost development
 Multihost training for Stable Diffusion 2 base can be run using the following command:
````
