Skip to content

Commit 4c9be69

Browse files
committed
baseline pipeline cleaned
1 parent fd9eb11 commit 4c9be69

25 files changed

Lines changed: 4221 additions & 4894 deletions

Whitespace-only changes.

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ activations_dtype: 'bfloat16'
1010
run_name: ''
1111
output_dir: 'ltx-video-output'
1212
save_config_to_gcs: False
13+
1314
#Checkpoints
1415
text_encoder_model_name_or_path: "ariG23498/t5-v1-1-xxl-flax"
1516
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
@@ -23,12 +24,13 @@ sampler: "from_checkpoint"
2324

2425

2526
# Generation parameters
26-
pipeline_type: multi-scale
27+
pipeline_type: None
2728
prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie."
2829
height: 512
2930
width: 512
3031
num_frames: 88 #344
3132
flow_shift: 5.0
33+
fps: 24
3234
downscale_factor: 0.6666666
3335
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
3436
prompt_enhancement_words_threshold: 120

src/maxdiffusion/generate_ltx_video.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,16 @@
44
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
55
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline
66
from maxdiffusion import pyconfig
7+
import jax.numpy as jnp
8+
from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
9+
from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler
10+
from huggingface_hub import hf_hub_download
711
import imageio
812
from datetime import datetime
13+
from maxdiffusion.utils import export_to_video
14+
915
import os
16+
import json
1017
import torch
1118
from pathlib import Path
1219

@@ -55,6 +62,13 @@ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
5562
return "-".join(result)
5663

5764

65+
def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
    """Load a pretrained latent upsampler ready for inference.

    Restores the model weights from ``latent_upsampler_model_path``, moves the
    module onto ``device``, and switches it to eval mode so that
    training-only behavior (e.g. dropout) is disabled.

    Returns the prepared ``LatentUpsampler`` instance.
    """
    model = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
    model.to(device)
    model.eval()
    return model
70+
71+
5872
def get_unique_filename(
5973
base: str,
6074
ext: str,
@@ -80,15 +94,52 @@ def run(config):
8094
width_padded = ((config.width - 1) // 32 + 1) * 32
8195
num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1
8296
padding = calculate_padding(config.height, config.width, height_padded, width_padded)
97+
# prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold
98+
# prompt_word_count = len(config.prompt.split())
99+
# enhance_prompt = (
100+
# prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold
101+
# )
83102

84-
seed = 10
103+
seed = 10 # change this, generator in pytorch, used in prepare_latents
85104
generator = torch.Generator().manual_seed(seed)
86105
pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=False)
87-
pipeline = LTXMultiScalePipeline(pipeline)
106+
if config.pipeline_type == "multi-scale": # move this to pipeline file??
107+
spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path
108+
109+
if spatial_upscaler_model_name_or_path and not os.path.isfile(spatial_upscaler_model_name_or_path):
110+
spatial_upscaler_model_path = hf_hub_download(
111+
repo_id="Lightricks/LTX-Video",
112+
filename=spatial_upscaler_model_name_or_path,
113+
local_dir="/mnt/disks/diffusionproj",
114+
repo_type="model",
115+
)
116+
else:
117+
spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
118+
if not config.spatial_upscaler_model_path:
119+
raise ValueError(
120+
"spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
121+
)
122+
latent_upsampler = create_latent_upsampler(spatial_upscaler_model_path, "cpu") # device set to cpu for now
123+
pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)
124+
stg_mode = config.stg_mode
125+
if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
126+
skip_layer_strategy = SkipLayerStrategy.AttentionValues
127+
elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
128+
skip_layer_strategy = SkipLayerStrategy.AttentionSkip
129+
elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
130+
skip_layer_strategy = SkipLayerStrategy.Residual
131+
elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
132+
skip_layer_strategy = SkipLayerStrategy.TransformerBlock
133+
else:
134+
raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")
135+
# images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded,
136+
# is_video=True, output_type='pt', generator=generator, guidance_scale = config.first_pass.guidance_scale, stg_scale = config.stg_scale, rescaling_scale = config.rescaling_scale, skip_initial_inference_steps= config.skip_initial_inference_steps, skip_final_inference_steps= config.skip_final_inference_steps, num_inference_steps = config.num_inference_steps,
137+
# guidance_timesteps = config.guidance_timesteps, cfg_star_rescale = config.cfg_star_rescale, skip_layer_strategy = None, skip_block_list=config.skip_block_list).images
88138
images = pipeline(
89139
height=height_padded,
90140
width=width_padded,
91141
num_frames=num_frames_padded,
142+
is_video=True,
92143
output_type="pt",
93144
generator=generator,
94145
config=config,

src/maxdiffusion/models/ltx_video/autoencoders/__init__.py

Whitespace-only changes.

src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py

Lines changed: 49 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -5,59 +5,54 @@
55

66

77
class CausalConv3d(nn.Module):
    """3D convolution that is causal along the time axis.

    Spatial (H, W) dimensions get symmetric "same" padding inside the
    underlying ``nn.Conv3d``; temporal padding is NOT done by the conv.
    Instead, ``forward`` replicates edge frames explicitly, so with
    ``causal=True`` no output frame ever depends on a future input frame.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        spatial_padding_mode: str = "zeros",
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Cubic kernel: same extent on (time, height, width).
        k_t = k_h = k_w = kernel_size
        self.time_kernel_size = k_t

        # Conv3d pads only spatially; temporal padding happens in forward().
        spatial_padding = (0, k_h // 2, k_w // 2)

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            (k_t, k_h, k_w),
            stride=stride,
            dilation=(dilation, 1, 1),  # dilation is applied along time only
            padding=spatial_padding,
            padding_mode=spatial_padding_mode,
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        """Convolve ``x`` (N, C, T, H, W) with explicit temporal padding.

        causal=True: the first frame is replicated ``k_t - 1`` times in
        front, so each output frame sees only current/past frames.
        causal=False: both temporal edges are replicated ``(k_t - 1) // 2``
        times (centered padding).
        """
        extra_frames = self.time_kernel_size - 1
        if causal:
            head = x[:, :, :1, :, :].repeat((1, 1, extra_frames, 1, 1))
            x = torch.concatenate((head, x), dim=2)
        else:
            head = x[:, :, :1, :, :].repeat((1, 1, extra_frames // 2, 1, 1))
            tail = x[:, :, -1:, :, :].repeat((1, 1, extra_frames // 2, 1, 1))
            x = torch.concatenate((head, x, tail), dim=2)
        return self.conv(x)

    @property
    def weight(self):
        # Expose the underlying conv kernel for callers that inspect weights.
        return self.conv.weight

0 commit comments

Comments
 (0)