* add support for flux vae. ~ wip
* test for flux vae both encoding and decoding.
* add clip text encoder test
* remove transformers inside maxdiffusion, add transformers dependency. Start creating generation code for flux.
* add double block to flux
* forward pass for single double block.
* trying to use scan.
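Scanning over stacked transformer blocks with `jax.lax.scan` keeps compile time flat as depth grows, since the block is traced once and iterated. A minimal sketch of the pattern (the block body and shapes are illustrative, not maxdiffusion's actual modules):

```python
import jax
import jax.numpy as jnp

def block(x, params):
    # Illustrative "transformer block": a residual linear layer.
    return x + x @ params["w"]

def forward(x, stacked_params):
    # stacked_params["w"] has shape (num_blocks, d, d): all blocks' weights
    # stacked along a leading axis so lax.scan can iterate over them.
    def body(carry, params):
        return block(carry, params), None
    out, _ = jax.lax.scan(body, x, stacked_params)
    return out

num_blocks, d = 4, 8
params = {"w": jnp.zeros((num_blocks, d, d))}  # zero weights: output == input
x = jnp.ones((2, d))
y = forward(x, params)
print(y.shape)  # (2, 8)
```

The same idea applies separately to the double-stream and single-stream block stacks, each scanned over its own stacked parameters.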
* add single stream block
* finish transformer
* convert pt weights to flax and load transformer state.
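Converting PyTorch weights to Flax mostly comes down to moving tensors to NumPy and transposing linear kernels, since PyTorch stores `nn.Linear` weights as `(out_features, in_features)` while Flax `Dense` expects `(in, out)`. A hedged sketch of that one step, using NumPy arrays to stand in for torch tensors (the helper name is illustrative):

```python
import numpy as np

def pt_linear_to_flax(pt_weight, pt_bias):
    """Convert a PyTorch nn.Linear weight/bias into Flax Dense params.
    PyTorch stores the weight as (out_features, in_features);
    Flax Dense expects a kernel of shape (in, out), hence the transpose."""
    return {"kernel": np.asarray(pt_weight).T, "bias": np.asarray(pt_bias)}

pt_w = np.arange(6, dtype=np.float32).reshape(3, 2)  # (out=3, in=2)
flax_params = pt_linear_to_flax(pt_w, np.zeros(3, dtype=np.float32))
print(flax_params["kernel"].shape)  # (2, 3)
```

The real conversion additionally remaps checkpoint key names to the Flax module tree; that mapping is model-specific.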
* apply fsdp sharding, do one forward pass in the transformer.
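FSDP-style sharding in JAX is expressed by placing parameters on a device mesh with a named axis, so each device holds one slice of each weight. A toy sketch with `jax.sharding` (single-device mesh so it runs anywhere; a real run spans all devices, and the axis name here only mirrors the `ici_fsdp_parallelism` axis in the configs):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Toy 1-D mesh over a single device; in a real FSDP setup the mesh
# spans all devices and the first weight dim must divide evenly.
devices = np.array(jax.devices()[:1])
mesh = Mesh(devices, ("fsdp",))

# Shard a parameter along its first dimension across the "fsdp" axis.
w = jnp.ones((8, 4))
sharded_w = jax.device_put(w, NamedSharding(mesh, PartitionSpec("fsdp", None)))
print(sharded_w.shape)  # (8, 4): global shape; each device holds a slice
```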
* wip - generate fn
* working loop, bad generation
* e2e, encoder offloading.
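Encoder offloading frees accelerator memory once the prompt embeddings are computed: the text-encoder weights are moved to host memory with `jax.device_put`. A minimal sketch (the params dict is a stand-in, not the real encoder state):

```python
import jax
import jax.numpy as jnp

# Illustrative stand-in for text-encoder params.
encoder_params = {"embed": jnp.ones((16, 8))}

# After prompt embeddings are computed, the encoder weights are no longer
# needed on the accelerator, so move the whole pytree to host memory.
cpu = jax.devices("cpu")[0]
offloaded = jax.device_put(encoder_params, cpu)
print(offloaded["embed"].devices())
```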
* support both dev and schnell loading. Images still incorrect.
* flux schnell working
* removed unused code.
* support dev
* add sentencepiece requirement
* fix repeated double and single blocks.
* optimized flash block sizes for trillium.
* clean up code and lint
* fix sdxl generate smoke tests.
* fix rest of unit tests.
* update readme and some dependencies.
* remove unused dependencies.
* initial lora implementation for flux
* add support for another LoRA format.
* support other LoRA formats; update readme; run code_style.
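A LoRA checkpoint ships a low-rank pair per target weight (commonly called `lora_down`/`A` and `lora_up`/`B`), and loading it means merging the scaled product into the base kernel. A sketch of the merge under the common `alpha / rank` scaling convention (key names and layouts vary between formats, which is why each format needs its own mapping logic):

```python
import numpy as np

def merge_lora(w, lora_down, lora_up, alpha):
    """Merge a low-rank LoRA delta into a base weight.
    w: (out, in); lora_down: (r, in); lora_up: (out, r).
    The alpha / r scaling follows the common LoRA convention."""
    r = lora_down.shape[0]
    return w + (alpha / r) * (lora_up @ lora_down)

w = np.zeros((4, 3), dtype=np.float32)
down = np.ones((2, 3), dtype=np.float32)  # rank r = 2
up = np.ones((4, 2), dtype=np.float32)
merged = merge_lora(w, down, up, alpha=2.0)
print(merged[0, 0])  # 2.0: (alpha / r) * sum over rank = 1.0 * 2
```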
* ruff
* fix typo in readme.
* Added FA support for GPUs
* ruff and code_style
* fixed final comments
* correct small mistake due to misunderstanding
---------
Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Co-authored-by: Juan Acevedo <juancevedo@gmail.com>
- [Comparison to Alternatives](#comparison-to-alternatives)
- [Development](#development)

# Getting Started
@@ -171,6 +176,24 @@ To generate images, run the following command:
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
```

## Flash Attention for GPU

Flash Attention for GPU is supported via TransformerEngine. Installation instructions:

```bash
cd maxdiffusion
pip install -U "jax[cuda12]"
pip install -r requirements.txt
pip install --upgrade torch torchvision
pip install "transformer_engine[jax]"
pip install .
```

Now run the command:

```bash
NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu
```

## Flux LoRA
Disclaimer: not all LoRA formats have been tested. If there is a specific LoRA that doesn't load, please let us know.