Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?
- **`2025/7/29`**: LTX-Video text2vid generation is now supported.
- **`2025/04/17`**: Flux Finetuning.
- **`2025/02/12`**: Flux LoRA for inference.
- **`2025/02/08`**: Flux schnell & dev inference.
Expand All @@ -41,6 +42,7 @@ MaxDiffusion supports
* Load Multiple LoRA (SDXL inference).
* ControlNet inference (Stable Diffusion 1.4 & SDXL).
* Dreambooth training support for Stable Diffusion 1.x,2.x.
* LTX-Video text2vid (inference).


# Table of Contents
Expand All @@ -53,6 +55,7 @@ MaxDiffusion supports
- [Training](#training)
- [Dreambooth](#dreambooth)
- [Inference](#inference)
- [LTX-Video](#ltx-video)
- [Flux](#flux)
- [Fused Attention for GPU:](#fused-attention-for-gpu)
- [Hyper SDXL LoRA](#hyper-sdxl-lora)
Expand Down Expand Up @@ -171,7 +174,16 @@ To generate images, run the following command:
```bash
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
```

## LTX-Video
- In the folder src/maxdiffusion/models/ltx_video/utils, run:
```bash
python convert_torch_weights_to_jax.py --ckpt_path [LOCAL DIRECTORY FOR WEIGHTS] --transformer_config_path ../xora_v1.2-13B-balanced-128.json
```
- In the repo folder, run:
```bash
python src/maxdiffusion/generate_ltx_video.py src/maxdiffusion/configs/ltx_video.yml output_dir="[SAME DIRECTORY]" config_path="src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json"
```
- Other generation parameters can be set in ltx_video.yml file.
## Flux

First make sure you have permissions to access the Flux repos in Huggingface.
Expand Down Expand Up @@ -205,7 +217,6 @@ To generate images, run the following command:
```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
```

## Fused Attention for GPU:
Fused Attention for GPU is supported via TransformerEngine. Installation instructions:

Expand Down Expand Up @@ -322,3 +333,5 @@ This script will automatically format your code with `pyink` and help you identi


The full suite of -end-to end tests is in `tests` and `src/maxdiffusion/tests`. We run them with a nightly cadance.


10 changes: 4 additions & 6 deletions src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,10 @@

def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, encoder_attention_segment_ids):
# Note: reference shape annotated for first pass default inference parameters
max_logging.log("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype) # (3, 256, 4096) float32
max_logging.log("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype) # (3, 3, 3072) float32
max_logging.log("latents.shape: ", latents.shape, latents.dtype) # (1, 3072, 128) float 32
max_logging.log(
"encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype
) # (3, 256) int32
max_logging.log(f"prompts_embeds.shape: {prompt_embeds.shape}") # (3, 256, 4096) float32
max_logging.log(f"fractional_coords.shape: {fractional_coords.shape}") # (3, 3, 3072) float32
max_logging.log(f"latents.shape: {latents.shape}") # (1, 3072, 128) float 32
max_logging.log(f"encoder_attention_segment_ids.shape: {encoder_attention_segment_ids.shape}") # (3, 256) int32


class LTXVideoPipeline:
Expand Down
Loading