Commit 0577d3e

pipeline cleaned
1 parent 4c9be69 commit 0577d3e

3 files changed: 149 additions & 157 deletions

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 3 additions & 7 deletions
@@ -19,16 +19,12 @@ frame_rate: 30
 max_sequence_length: 512
 sampler: "from_checkpoint"
 
-
-
-
-
 # Generation parameters
-pipeline_type: None
-prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie."
+pipeline_type: multi-scale
+prompt: "A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage."
 height: 512
 width: 512
-num_frames: 88 #344
+num_frames: 344 #344
 flow_shift: 5.0
 fps: 24
 downscale_factor: 0.6666666
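
For reference, run() in src/maxdiffusion/generate_ltx_video.py pads these values before they reach the pipeline. A quick sanity check of the new settings, using the padding formulas visible in the diff below (a sketch, not repo code):

    # Padding arithmetic from run(), applied to the new config values.
    height, width, num_frames = 512, 512, 344

    height_padded = ((height - 1) // 32 + 1) * 32             # 512 (already a multiple of 32)
    width_padded = ((width - 1) // 32 + 1) * 32               # 512
    num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1   # 345 (frame counts pad to 8n + 1)

    print(height_padded, width_padded, num_frames_padded)     # 512 512 345

So raising num_frames from 88 to 344 roughly quadruples the temporal length the pipeline runs at: 345 padded frames instead of 89.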

src/maxdiffusion/generate_ltx_video.py

Lines changed: 11 additions & 55 deletions
@@ -4,16 +4,12 @@
 from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
 from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline
 from maxdiffusion import pyconfig
-import jax.numpy as jnp
 from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
-from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler
 from huggingface_hub import hf_hub_download
 import imageio
 from datetime import datetime
-from maxdiffusion.utils import export_to_video
 
 import os
-import json
 import torch
 from pathlib import Path
 
@@ -62,25 +58,19 @@ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
   return "-".join(result)
 
 
-def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
-  latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
-  latent_upsampler.to(device)
-  latent_upsampler.eval()
-  return latent_upsampler
 
 
 def get_unique_filename(
     base: str,
     ext: str,
     prompt: str,
-    seed: int,
     resolution: tuple[int, int, int],
     dir: Path,
     endswith=None,
     index_range=1000,
 ) -> Path:
   base_filename = (
-      f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
+      f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
   )
   for i in range(index_range):
     filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
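
With seed gone from get_unique_filename, output names follow {base}_{prompt-slug}_{height}x{width}x{frames}_{index}{ext}. A minimal illustration of the new pattern (the slug and output directory are hypothetical stand-ins; in the repo the slug comes from convert_prompt_to_filename(prompt, max_len=30)):

    from pathlib import Path

    base, slug = "video_output_0", "a-woman-with-light-skin"  # hypothetical slug
    height, width, num_frames = 512, 512, 344

    base_filename = f"{base}_{slug}_{height}x{width}x{num_frames}"
    print(Path("output") / f"{base_filename}_0.mp4")
    # output/video_output_0_a-woman-with-light-skin_512x512x344_0.mp4

Two runs with different seeds but the same prompt and resolution now produce the same base name, which is why the _{i} uniqueness loop above matters.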
@@ -94,55 +84,23 @@ def run(config):
   width_padded = ((config.width - 1) // 32 + 1) * 32
   num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1
   padding = calculate_padding(config.height, config.width, height_padded, width_padded)
-  # prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold
-  # prompt_word_count = len(config.prompt.split())
-  # enhance_prompt = (
-  #     prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold
-  # )
-
-  seed = 10 # change this, generator in pytorch, used in prepare_latents
-  generator = torch.Generator().manual_seed(seed)
-  pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=False)
-  if config.pipeline_type == "multi-scale": # move this to pipeline file??
-    spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path
-
-    if spatial_upscaler_model_name_or_path and not os.path.isfile(spatial_upscaler_model_name_or_path):
-      spatial_upscaler_model_path = hf_hub_download(
-          repo_id="Lightricks/LTX-Video",
-          filename=spatial_upscaler_model_name_or_path,
-          local_dir="/mnt/disks/diffusionproj",
-          repo_type="model",
-      )
-    else:
-      spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
-    if not config.spatial_upscaler_model_path:
-      raise ValueError(
-          "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
-      )
-    latent_upsampler = create_latent_upsampler(spatial_upscaler_model_path, "cpu") # device set to cpu for now
-    pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)
-  stg_mode = config.stg_mode
-  if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
-    skip_layer_strategy = SkipLayerStrategy.AttentionValues
-  elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
-    skip_layer_strategy = SkipLayerStrategy.AttentionSkip
-  elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
-    skip_layer_strategy = SkipLayerStrategy.Residual
-  elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
-    skip_layer_strategy = SkipLayerStrategy.TransformerBlock
-  else:
-    raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")
-  # images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded,
-  #     is_video=True, output_type='pt', generator=generator, guidance_scale = config.first_pass.guidance_scale, stg_scale = config.stg_scale, rescaling_scale = config.rescaling_scale, skip_initial_inference_steps= config.skip_initial_inference_steps, skip_final_inference_steps= config.skip_final_inference_steps, num_inference_steps = config.num_inference_steps,
-  #     guidance_timesteps = config.guidance_timesteps, cfg_star_rescale = config.cfg_star_rescale, skip_layer_strategy = None, skip_block_list=config.skip_block_list).images
+  prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold
+  prompt_word_count = len(config.prompt.split())
+  enhance_prompt = (
+      prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold
+  )
+
+  pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=enhance_prompt)
+  if config.pipeline_type == "multi-scale":
+    pipeline = LTXMultiScalePipeline(pipeline)
   images = pipeline(
       height=height_padded,
       width=width_padded,
       num_frames=num_frames_padded,
       is_video=True,
       output_type="pt",
-      generator=generator,
       config=config,
+      enhance_prompt = False
   )
   (pad_left, pad_right, pad_top, pad_bottom) = padding
   pad_bottom = -pad_bottom
@@ -167,7 +125,6 @@ def run(config):
         f"image_output_{i}",
         ".png",
         prompt=config.prompt,
-        seed=seed,
         resolution=(height, width, config.num_frames),
         dir=output_dir,
     )
@@ -177,7 +134,6 @@ def run(config):
         f"video_output_{i}",
         ".mp4",
         prompt=config.prompt,
-        seed=seed,
         resolution=(height, width, config.num_frames),
         dir=output_dir,
     )
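
The hardcoded enhance_prompt=False at load time is replaced by a word-count gate: enhancement is requested only when the configured threshold is positive and the prompt is shorter than it. A standalone sketch of that gate, mirroring the added lines above (threshold value assumed for illustration):

    prompt = "A woman with light skin, wearing a blue jacket and a black hat with a veil, ..."  # abbreviated
    prompt_enhancement_words_threshold = 120  # assumed config value

    prompt_word_count = len(prompt.split())
    enhance_prompt = (
        prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold
    )
    print(enhance_prompt)  # True here; always False when the threshold is <= 0

Note that the pipeline call itself still passes enhance_prompt = False, so the gated value only affects LTXVideoPipeline.from_pretrained.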
