|
4 | 4 | from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline |
5 | 5 | from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline |
6 | 6 | from maxdiffusion import pyconfig |
7 | | -import jax.numpy as jnp |
8 | | -from maxdiffusion.models.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy |
9 | 7 | from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler |
10 | 8 | from huggingface_hub import hf_hub_download |
11 | 9 | import imageio |
12 | 10 | from datetime import datetime |
13 | | -from maxdiffusion.utils import export_to_video |
14 | 11 |
|
15 | 12 | import os |
16 | | -import json |
17 | 13 | import torch |
18 | 14 | from pathlib import Path |
19 | 15 |
|
@@ -96,52 +92,12 @@ def run(config): |
96 | 92 | num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1 |
97 | 93 | padding = calculate_padding( |
98 | 94 | config.height, config.width, height_padded, width_padded) |
99 | | - # prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold |
100 | | - # prompt_word_count = len(config.prompt.split()) |
101 | | - # enhance_prompt = ( |
102 | | - # prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold |
103 | | - # ) |
104 | 95 |
|
105 | | - seed = 10 # change this, generator in pytorch, used in prepare_latents |
| 96 | + seed = 10 |
106 | 97 | generator = torch.Generator().manual_seed(seed) |
107 | 98 | pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt = False) |
108 | | - if config.pipeline_type == "multi-scale": #move this to pipeline file?? |
109 | | - spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path |
110 | | - |
111 | | - if spatial_upscaler_model_name_or_path and not os.path.isfile( |
112 | | - spatial_upscaler_model_name_or_path |
113 | | - ): |
114 | | - spatial_upscaler_model_path = hf_hub_download( |
115 | | - repo_id="Lightricks/LTX-Video", |
116 | | - filename=spatial_upscaler_model_name_or_path, |
117 | | - local_dir= "/mnt/disks/diffusionproj", |
118 | | - repo_type="model", |
119 | | - ) |
120 | | - else: |
121 | | - spatial_upscaler_model_path = spatial_upscaler_model_name_or_path |
122 | | - if not config.spatial_upscaler_model_path: |
123 | | - raise ValueError( |
124 | | - "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering" |
125 | | - ) |
126 | | - latent_upsampler = create_latent_upsampler( |
127 | | - spatial_upscaler_model_path, "cpu" #device set to cpu for now |
128 | | - ) |
129 | | - pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler) |
130 | | - stg_mode = config.stg_mode |
131 | | - if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values": |
132 | | - skip_layer_strategy = SkipLayerStrategy.AttentionValues |
133 | | - elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip": |
134 | | - skip_layer_strategy = SkipLayerStrategy.AttentionSkip |
135 | | - elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual": |
136 | | - skip_layer_strategy = SkipLayerStrategy.Residual |
137 | | - elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block": |
138 | | - skip_layer_strategy = SkipLayerStrategy.TransformerBlock |
139 | | - else: |
140 | | - raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}") |
141 | | - # images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, |
142 | | - # is_video=True, output_type='pt', generator=generator, guidance_scale = config.first_pass.guidance_scale, stg_scale = config.stg_scale, rescaling_scale = config.rescaling_scale, skip_initial_inference_steps= config.skip_initial_inference_steps, skip_final_inference_steps= config.skip_final_inference_steps, num_inference_steps = config.num_inference_steps, |
143 | | - # guidance_timesteps = config.guidance_timesteps, cfg_star_rescale = config.cfg_star_rescale, skip_layer_strategy = None, skip_block_list=config.skip_block_list).images |
144 | | - images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, is_video=True, output_type='pt', generator=generator, config = config) |
| 99 | + pipeline = LTXMultiScalePipeline(pipeline) |
| 100 | + images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, output_type='pt', generator=generator, config = config) |
145 | 101 | (pad_left, pad_right, pad_top, pad_bottom) = padding |
146 | 102 | pad_bottom = -pad_bottom |
147 | 103 | pad_right = -pad_right |
|
0 commit comments