Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
3776190
set up files for ltxvid
Serenagu525 Jun 26, 2025
13656fb
ltx-video-transformer-setup
Serenagu525 Jun 26, 2025
7bed4f9
formatting
Serenagu525 Jun 26, 2025
7e098c5
format fixed
Serenagu525 Jun 26, 2025
e18128c
transformer step and test
Serenagu525 Jun 30, 2025
1c55452
removed diffusers import
Serenagu525 Jun 30, 2025
fd4af91
fixed mesh
Serenagu525 Jun 30, 2025
5e17a62
changed path
Serenagu525 Jul 1, 2025
fc60b27
changed path
Serenagu525 Jul 1, 2025
3243535
changed config path
Serenagu525 Jul 1, 2025
e873a17
ruff check
Serenagu525 Jul 1, 2025
d06dee3
changed back pyconfig
Serenagu525 Jul 2, 2025
1ea6590
ruff check
Serenagu525 Jul 2, 2025
aa7befd
changed sharding back
Serenagu525 Jul 2, 2025
d9a3502
removed testing for now
Serenagu525 Jul 5, 2025
a1ad421
Update pyconfig.py
Serenagu525 Jul 5, 2025
615174f
Update max_utils.py
Serenagu525 Jul 5, 2025
7469c62
Update ltx_video.yml
Serenagu525 Jul 5, 2025
6de4424
Delete src/maxdiffusion/tests/ltx_vid_transformer_test_ref_pred
Serenagu525 Jul 5, 2025
18ec247
Delete src/maxdiffusion/tests/ltx_transformer_step_test.py
Serenagu525 Jul 5, 2025
8a043f6
sharding back
Serenagu525 Jul 9, 2025
35a3337
added test
Serenagu525 Jul 9, 2025
546ecab
ruff fixed
Serenagu525 Jul 9, 2025
12a247f
added header
Serenagu525 Jul 9, 2025
1062c72
license headers
Serenagu525 Jul 9, 2025
535c75e
exclude test
Serenagu525 Jul 9, 2025
7af151a
change base branch
Serenagu525 Jul 10, 2025
634591b
save now
Serenagu525 Jul 10, 2025
a272d08
load transformer error
Serenagu525 Jul 10, 2025
4bcffd1
later
Serenagu525 Jul 11, 2025
f5afa91
changed repeatable layer
Serenagu525 Jul 11, 2025
bb61ecb
functional
Serenagu525 Jul 11, 2025
7d4b2a9
moved upsampler
Serenagu525 Jul 11, 2025
972e316
initial cleaning
Serenagu525 Jul 11, 2025
c375471
multiscale pipeline
Serenagu525 Jul 16, 2025
f63a6fa
remove init
Serenagu525 Jul 16, 2025
b3874f5
new empty folders
Serenagu525 Jul 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
ruff check .
- name: PyTest
run: |
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest -x --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py
# add_pull_ready:
# if: github.ref != 'refs/heads/main'
# permissions:
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
_import_structure["schedulers"].extend(
[
Expand Down Expand Up @@ -453,6 +454,7 @@
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from .models.ltx_video.transformers.transformer3d import Transformer3DModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipelines import FlaxDiffusionPipeline
from .schedulers import (
Expand Down
7 changes: 5 additions & 2 deletions src/maxdiffusion/checkpointing/checkpointing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,11 @@ def load_state_if_possible(
max_logging.log(f"restoring from this run's directory latest step {latest_step}")
try:
if not enable_single_replica_ckpt_restoring:
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
if checkpoint_item == " ":
return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state))
else:
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))

def map_to_pspec(data):
pspec = data.sharding.spec
Expand Down
101 changes: 101 additions & 0 deletions src/maxdiffusion/configs/ltx_video.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Hardware
hardware: 'tpu'
skip_jax_distributed_system: False

jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'


run_name: ''
output_dir: 'ltx-video-output'
save_config_to_gcs: False
# Checkpoints
text_encoder_model_name_or_path: "ariG23498/t5-v1-1-xxl-flax"
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
frame_rate: 30
max_sequence_length: 512
sampler: "from_checkpoint"





# Generation parameters
pipeline_type: multi-scale
prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie."
height: 512
width: 512
num_frames: 88 # alternative full-length setting: 344
flow_shift: 5.0
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
prompt_enhancement_words_threshold: 120
stg_mode: "attention_values"
decode_timestep: 0.05
decode_noise_scale: 0.025
models_dir: "/mnt/disks/diffusionproj" # directory containing the safetensors weight file(s)


first_pass:
guidance_scale: [1, 1, 6, 8, 6, 1, 1]
stg_scale: [0, 0, 4, 4, 4, 2, 1]
rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
num_inference_steps: 30
skip_final_inference_steps: 3
skip_initial_inference_steps: 0
cfg_star_rescale: True

second_pass:
guidance_scale: [1]
stg_scale: [1]
rescaling_scale: [1]
guidance_timesteps: [1.0]
skip_block_list: [27]
num_inference_steps: 30
skip_initial_inference_steps: 17
skip_final_inference_steps: 0
cfg_star_rescale: True

#parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
logical_axis_rules: [
['batch', 'data'],
['activation_heads', 'fsdp'],
['activation_batch', ['data','fsdp']],
['activation_kv', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['norm', 'fsdp'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['conv_in', 'fsdp']
]
data_sharding: [['data', 'fsdp', 'tensor']]
dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1 # recommended DCN axis to be auto-sharded (-1 = use remaining devices)
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1



learning_rate_schedule_steps: -1
max_train_steps: 500 #TODO: change this
pretrained_model_name_or_path: ''
unet_checkpoint: ''
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
jit_initializers: True
enable_single_replica_ckpt_restoring: False
146 changes: 146 additions & 0 deletions src/maxdiffusion/generate_ltx_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import numpy as np
from absl import app
from typing import Sequence
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline
from maxdiffusion import pyconfig
import imageio
from datetime import datetime
import os
import torch
from pathlib import Path


def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Compute the per-side padding that grows a (source_height, source_width)
    frame to exactly (target_height, target_width).

    Returns:
        A 4-tuple in ``(left, right, top, bottom)`` order. When the total
        padding along an axis is odd, the extra pixel goes to the
        bottom/right side.
    """
    total_height = target_height - source_height
    total_width = target_width - source_width

    pad_top = total_height // 2
    pad_left = total_width // 2
    # The subtraction absorbs any odd remainder on the bottom/right edges.
    return (pad_left, total_width - pad_left, pad_top, total_height - pad_top)


def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    """Build a short, filesystem-safe slug from a free-text prompt.

    Keeps only letters and spaces (lowercased), then joins as many leading
    words as fit within ``max_len`` characters, separated by ``-``. The
    returned string never exceeds ``max_len``.

    Args:
        text: Arbitrary prompt text.
        max_len: Maximum length of the returned slug.

    Returns:
        A hyphen-joined, lowercase slug (possibly empty).
    """
    # Strip everything except letters and spaces, lowercased.
    clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace())

    # Split into words
    words = clean_text.split()

    # Build result string keeping track of length
    result = []
    current_length = 0

    for word in words:
        # BUG FIX: the original never counted the "-" separator that
        # "-".join() inserts before every word except the first, so the
        # resulting filename could exceed max_len.
        separator_length = 1 if result else 0
        new_length = current_length + separator_length + len(word)

        if new_length <= max_len:
            result.append(word)
            current_length = new_length
        else:
            break

    return "-".join(result)


def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    seed: int,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
    """Return the first non-existing path of the form
    ``{dir}/{base}_{slug}_{seed}_{RxCxF}_{i}{endswith}{ext}`` for
    ``i`` in ``[0, index_range)``.

    Raises:
        FileExistsError: if every candidate index is already taken.
    """
    slug = convert_prompt_to_filename(prompt, max_len=30)
    stem = f"{base}_{slug}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
    suffix = endswith if endswith else ""
    for index in range(index_range):
        candidate = dir / f"{stem}_{index}{suffix}{ext}"
        if not candidate.exists():
            return candidate
    raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.")


def run(config):
    """Generate a video (or single image) with the LTX-Video multi-scale
    pipeline and write the result under ``outputs/<today>/``.

    Args:
        config: A pyconfig config object; reads height, width, num_frames,
            prompt, frame_rate, and the pipeline checkpoint settings.
    """
    # The pipeline requires spatial dims that are multiples of 32 and a frame
    # count of the form 8k+1; round up, then remember the padding so the final
    # output can be cropped back to the requested size.
    height_padded = ((config.height - 1) // 32 + 1) * 32
    width_padded = ((config.width - 1) // 32 + 1) * 32
    num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1
    padding = calculate_padding(config.height, config.width, height_padded, width_padded)

    # Fixed seed for reproducible generation.
    seed = 10
    generator = torch.Generator().manual_seed(seed)
    pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=False)
    pipeline = LTXMultiScalePipeline(pipeline)
    # Output is a tensor; indexing below assumes (B, C, F, H, W) layout.
    images = pipeline(
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        output_type="pt",
        generator=generator,
        config=config,
    )
    (pad_left, pad_right, pad_top, pad_bottom) = padding
    # Negate right/bottom so they can be used as end indices in slices;
    # a zero pad must instead become the full extent, because x[a:-0] == x[a:0]
    # would produce an empty slice.
    pad_bottom = -pad_bottom
    pad_right = -pad_right
    if pad_bottom == 0:
        pad_bottom = images.shape[3]
    if pad_right == 0:
        pad_right = images.shape[4]
    # Crop padding and trim extra frames back to the requested num_frames.
    images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right]
    output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    output_dir.mkdir(parents=True, exist_ok=True)
    for i in range(images.shape[0]):
        # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
        video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy()
        # Unnormalizing images to [0, 255] range
        # (assumes pipeline output is normalized to [0, 1] — TODO confirm)
        video_np = (video_np * 255).astype(np.uint8)
        fps = config.frame_rate
        height, width = video_np.shape[1:3]
        # In case a single image is generated
        if video_np.shape[0] == 1:
            output_filename = get_unique_filename(
                f"image_output_{i}",
                ".png",
                prompt=config.prompt,
                seed=seed,
                resolution=(height, width, config.num_frames),
                dir=output_dir,
            )
            imageio.imwrite(output_filename, video_np[0])
        else:
            output_filename = get_unique_filename(
                f"video_output_{i}",
                ".mp4",
                prompt=config.prompt,
                seed=seed,
                resolution=(height, width, config.num_frames),
                dir=output_dir,
            )
            print(output_filename)
            # Write video
            with imageio.get_writer(output_filename, fps=fps) as video:
                for frame in video_np:
                    video.append_data(frame)


def main(argv: Sequence[str]) -> None:
    """CLI entry point: initialize the global pyconfig from argv, then run generation."""
    pyconfig.initialize(argv)
    config = pyconfig.config
    run(config)


if __name__ == "__main__":
app.run(main)
5 changes: 4 additions & 1 deletion src/maxdiffusion/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,10 @@ def setup_initial_state(
config.enable_single_replica_ckpt_restoring,
)
if state:
state = state[checkpoint_item]
if checkpoint_item == " ":
state = state
else:
state = state[checkpoint_item]
if not state:
max_logging.log(f"Could not find the item in orbax, creating state...")
init_train_state_partial = functools.partial(
Expand Down
5 changes: 2 additions & 3 deletions src/maxdiffusion/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
# limitations under the License.

from typing import TYPE_CHECKING

from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available

from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available

_import_structure = {}

Expand All @@ -32,6 +30,7 @@
from .vae_flax import FlaxAutoencoderKL
from .lora import *
from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from .ltx_video.transformers.transformer3d import Transformer3DModel

else:
import sys
Expand Down
15 changes: 15 additions & 0 deletions src/maxdiffusion/models/ltx_video/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
Loading
Loading