Skip to content

Commit 972e316

Browse files
committed
initial cleaning
1 parent 7d4b2a9 commit 972e316

4 files changed

Lines changed: 191 additions & 369 deletions

File tree

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,10 @@ flow_shift: 5.0
3232
downscale_factor: 0.6666666
3333
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
3434
prompt_enhancement_words_threshold: 120
35-
# guidance_scale: [1, 1, 6, 8, 6, 1, 1] #4.5
36-
# stg_scale: [0, 0, 4, 4, 4, 2, 1] #1.0
37-
# rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] #0.7
38-
# num_inference_steps: 30
39-
# skip_final_inference_steps: 3
40-
# skip_initial_inference_steps: 0
41-
# guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
42-
# skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
4335
stg_mode: "attention_values"
4436
decode_timestep: 0.05
4537
decode_noise_scale: 0.025
46-
# cfg_star_rescale: True
38+
models_dir: "/mnt/disks/diffusionproj"  # directory where the safetensors file is located
4739

4840

4941
first_pass:

src/maxdiffusion/generate_ltx_video.py

Lines changed: 3 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,12 @@
44
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
55
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline
66
from maxdiffusion import pyconfig
7-
import jax.numpy as jnp
8-
from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
97
from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler
108
from huggingface_hub import hf_hub_download
119
import imageio
1210
from datetime import datetime
13-
from maxdiffusion.utils import export_to_video
1411

1512
import os
16-
import json
1713
import torch
1814
from pathlib import Path
1915

@@ -96,52 +92,12 @@ def run(config):
9692
num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1
9793
padding = calculate_padding(
9894
config.height, config.width, height_padded, width_padded)
99-
# prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold
100-
# prompt_word_count = len(config.prompt.split())
101-
# enhance_prompt = (
102-
# prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold
103-
# )
10495

105-
seed = 10 # change this, generator in pytorch, used in prepare_latents
96+
seed = 10
10697
generator = torch.Generator().manual_seed(seed)
10798
pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt = False)
108-
if config.pipeline_type == "multi-scale": #move this to pipeline file??
109-
spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path
110-
111-
if spatial_upscaler_model_name_or_path and not os.path.isfile(
112-
spatial_upscaler_model_name_or_path
113-
):
114-
spatial_upscaler_model_path = hf_hub_download(
115-
repo_id="Lightricks/LTX-Video",
116-
filename=spatial_upscaler_model_name_or_path,
117-
local_dir= "/mnt/disks/diffusionproj",
118-
repo_type="model",
119-
)
120-
else:
121-
spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
122-
if not config.spatial_upscaler_model_path:
123-
raise ValueError(
124-
"spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
125-
)
126-
latent_upsampler = create_latent_upsampler(
127-
spatial_upscaler_model_path, "cpu" #device set to cpu for now
128-
)
129-
pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)
130-
stg_mode = config.stg_mode
131-
if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
132-
skip_layer_strategy = SkipLayerStrategy.AttentionValues
133-
elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
134-
skip_layer_strategy = SkipLayerStrategy.AttentionSkip
135-
elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
136-
skip_layer_strategy = SkipLayerStrategy.Residual
137-
elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
138-
skip_layer_strategy = SkipLayerStrategy.TransformerBlock
139-
else:
140-
raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")
141-
# images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded,
142-
# is_video=True, output_type='pt', generator=generator, guidance_scale = config.first_pass.guidance_scale, stg_scale = config.stg_scale, rescaling_scale = config.rescaling_scale, skip_initial_inference_steps= config.skip_initial_inference_steps, skip_final_inference_steps= config.skip_final_inference_steps, num_inference_steps = config.num_inference_steps,
143-
# guidance_timesteps = config.guidance_timesteps, cfg_star_rescale = config.cfg_star_rescale, skip_layer_strategy = None, skip_block_list=config.skip_block_list).images
144-
images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, is_video=True, output_type='pt', generator=generator, config = config)
99+
pipeline = LTXMultiScalePipeline(pipeline)
100+
images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, output_type='pt', generator=generator, config = config)
145101
(pad_left, pad_right, pad_top, pad_bottom) = padding
146102
pad_bottom = -pad_bottom
147103
pad_right = -pad_right

src/maxdiffusion/models/ltx_video/repeatable_layer.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ def __call__(self, carry: Tuple[jax.Array, jax.Array], *block_args) -> Tuple[Tup
2525

2626
mod = self.module(*self.module_init_args, **self.module_init_kwargs)
2727

28-
# block_args are the static arguments passed to each individual block
29-
output_data = mod(index_input, data_input, *block_args) # Pass block_args to the module
28+
output_data = mod(index_input, data_input, *block_args) # Pass index_input to facilitate skip layers
3029

3130
next_index = index_input + 1
3231
new_carry = (output_data, next_index)
@@ -76,14 +75,14 @@ class RepeatableLayer(nn.Module):
7675
"""
7776

7877
@nn.compact
79-
def __call__(self, *args): # args is now the full input to RepeatableLayer
78+
def __call__(self, *args):
8079
if not args:
8180
raise ValueError("RepeatableLayer expects at least one argument for initial data input.")
8281

83-
initial_data_input = args[0] # The first element is your main data input
84-
static_block_args = args[1:] # Any subsequent elements are static args for each block
82+
initial_data_input = args[0]
83+
static_block_args = args[1:]
8584

86-
initial_index = jnp.array(0, dtype=jnp.int32)
85+
initial_index = jnp.array(0, dtype=jnp.int32)  # index of the current transformer block, starting at 0
8786

8887
scan_kwargs = {}
8988
if self.pspec_name is not None:
@@ -92,9 +91,6 @@ def __call__(self, *args): # args is now the full input to RepeatableLayer
9291
initializing = self.is_mutable_collection("params")
9392
params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis)
9493

95-
# in_axes for the scanned function (RepeatableCarryBlock.__call__):
96-
# 1. The 'carry' tuple ((0, 0))
97-
# 2. Then, nn.broadcast for each of the `static_block_args`
9894
in_axes_for_scan = (nn.broadcast,) * (len(args)-1)
9995

10096
scan_fn = nn.scan(
@@ -117,5 +113,4 @@ def __call__(self, *args): # args is now the full input to RepeatableLayer
117113
# Call wrapped_function with the initial carry tuple and the static_block_args
118114
(final_data, final_index), _ = wrapped_function((initial_data_input, initial_index), *static_block_args)
119115

120-
# Typically, you only want the final data output from the sequence of layers
121116
return final_data

0 commit comments

Comments
 (0)