1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- import argparse
1514import math
1615import os
17- import random
1816from jax import Array
19- from datetime import datetime
20- from pathlib import Path
17+ from maxdiffusion .models .ltx_video .autoencoders .latent_upsampler import LatentUpsampler
2118from diffusers import AutoencoderKL
2219from typing import Optional , List , Union , Tuple
2320from einops import rearrange
2421import torch .nn .functional as F
2522from diffusers .utils .torch_utils import randn_tensor
26- # from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
27- import yaml
28- from transformers import (CLIPTokenizer , FlaxCLIPTextModel ,
29- T5EncoderModel , FlaxT5EncoderModel , AutoTokenizer )
30-
31-
32- import imageio
33- import json
34- import numpy as np
35- import torch
36- from safetensors import safe_open
37- from PIL import Image
3823from transformers import (
39- T5EncoderModel ,
40- T5Tokenizer ,
24+ FlaxT5EncoderModel ,
25+ AutoTokenizer ,
4126 AutoModelForCausalLM ,
4227 AutoProcessor ,
43- AutoTokenizer ,
44- )
28+ AutoTokenizer ,)
29+ import json
30+ import numpy as np
31+ import torch
4532from huggingface_hub import hf_hub_download
4633from maxdiffusion .models .ltx_video .autoencoders .causal_video_autoencoder import (
4734 CausalVideoAutoencoder ,
5542 normalize_latents ,
5643)
5744from diffusers .image_processor import VaeImageProcessor
58- from ltx_video .schedulers .rf import RectifiedFlowScheduler
5945from maxdiffusion .models .ltx_video .autoencoders .latent_upsampler import LatentUpsampler
60- import ltx_video .pipelines .crf_compressor as crf_compressor
6146from maxdiffusion .models .ltx_video .utils .prompt_enhance_utils import generate_cinematic_prompt
6247from math import e
6348from types import NoneType
6449from typing import Any , Dict
6550import numpy as np
66- import inspect
6751
6852import jax
6953import jax .numpy as jnp
7054from jax .sharding import Mesh , PartitionSpec as P
7155from typing import Optional , Union , List
72- import torch
73- from maxdiffusion .checkpointing import checkpointing_utils
74- from flax .linen import partitioning as nn_partitioning
7556from maxdiffusion .models .ltx_video .transformers .symmetric_patchifier import SymmetricPatchifier
76- from maxdiffusion .models .ltx_video .utils .skip_layer_strategy import SkipLayerStrategy
7757from ...pyconfig import HyperParameters
78- # from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState
7958from ...schedulers .scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler , RectifiedFlowSchedulerState
8059from ...max_utils import (
8160 create_device_mesh ,
8261 setup_initial_state ,
8362 get_memory_allocations
8463)
8564from maxdiffusion .models .ltx_video .transformers .transformer3d import Transformer3DModel
86- import os
8765import json
8866import functools
8967import orbax .checkpoint as ocp
90- import pickle
91-
92-
class PickleCheckpointHandler(ocp.CheckpointHandler):
  """Minimal Orbax ``CheckpointHandler`` that serializes the whole item with pickle.

  The entire item is stored as a single ``checkpoint.pkl`` file inside the
  checkpoint directory, so anything picklable can be checkpointed without
  describing its tree structure to Orbax.
  """

  def save(self, directory: str, item, args=None):
    """Pickle ``item`` to ``<directory>/checkpoint.pkl``, creating the directory.

    ``args`` is accepted for Orbax handler-interface compatibility and ignored.
    """
    os.makedirs(directory, exist_ok=True)
    with open(os.path.join(directory, "checkpoint.pkl"), "wb") as f:
      pickle.dump(item, f)

  def restore(self, directory: str, args=None):
    """Load and return the pickled item from ``<directory>/checkpoint.pkl``.

    SECURITY NOTE: ``pickle.load`` executes arbitrary code embedded in the
    file — only restore checkpoints from trusted locations.
    """
    with open(os.path.join(directory, "checkpoint.pkl"), "rb") as f:
      return pickle.load(f)

  def structure(self, directory: str):
    # Orbax may query the checkpoint structure; a pickle blob has none, so
    # return an empty mapping.
    return {}
105-
106-
def save_tensor_dict(tensor_dict, timestep):
  """Persist a dict of tensors next to this module as ``schedulerTest<timestep>``.

  Debug utility: serializes ``tensor_dict`` with ``torch.save``. On failure the
  error is printed and the exception re-raised; on success the target path is
  printed.
  """
  target = os.path.join(os.path.dirname(__file__), f"schedulerTest{timestep}")
  try:
    torch.save(tensor_dict, target)
  except Exception as err:
    print(f"Error saving dictionary: {err}")
    raise
  else:
    print(f"Dictionary of tensors saved to: {target}")
117-
118-
def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond, segment_ids, encoder_attention_segment_ids):
  """Debug helper: print shape and dtype of each transformer input.

  Fixes a copy-paste defect in the original (``noise_cond`` was printed twice)
  and the ``prompts_embeds`` label typo. ``segment_ids`` is accepted for call
  compatibility but intentionally not printed (it may be None at call sites).
  """
  print("prompt_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
  print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
  print("latents.shape: ", latents.shape, latents.dtype)
  print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)
  # segment_ids deliberately skipped — see docstring.
  print("encoder_attention_segment_ids.shape: ", encoder_attention_segment_ids.shape, encoder_attention_segment_ids.dtype)
129-
13068
13169def prepare_extra_step_kwargs (generator ):
13270 extra_step_kwargs = {}
@@ -817,7 +755,6 @@ def __call__(
817755 skip_initial_inference_steps : int = 0 ,
818756 skip_final_inference_steps : int = 0 ,
819757 cfg_star_rescale : bool = False ,
820- skip_layer_strategy : Optional [SkipLayerStrategy ] = None ,
821758 skip_block_list : Optional [Union [List [List [int ]], List [int ]]] = None ,
822759 ** kwargs ,
823760 ):
@@ -1076,7 +1013,6 @@ def __call__(
10761013 rescaling_scale = rescaling_scale ,
10771014 batch_size = batch_size ,
10781015 skip_layer_masks = skip_layer_masks ,
1079- skip_layer_strategy = skip_layer_strategy ,
10801016 cfg_star_rescale = cfg_star_rescale
10811017 )
10821018
@@ -1149,7 +1085,6 @@ def transformer_forward_pass( # need to jit this? wan didnt
11491085 segment_ids ,
11501086 encoder_attention_segment_ids ,
11511087 skip_layer_mask ,
1152- skip_layer_strategy ,
11531088):
11541089 noise_pred = transformer .apply (
11551090 {"params" : state .params },
@@ -1160,13 +1095,12 @@ def transformer_forward_pass( # need to jit this? wan didnt
11601095 segment_ids = segment_ids ,
11611096 encoder_attention_segment_ids = encoder_attention_segment_ids ,
11621097 skip_layer_mask = skip_layer_mask ,
1163- skip_layer_strategy = skip_layer_strategy
1164- ) # need .param here?
1098+ )
11651099 return noise_pred , state
11661100
11671101
11681102def run_inference (
1169- transformer_state , transformer , config , mesh , latents , fractional_cords , prompt_embeds , timestep , num_inference_steps , scheduler , segment_ids , encoder_attention_segment_ids , scheduler_state , do_classifier_free_guidance , num_conds , guidance_scale , do_spatio_temporal_guidance , stg_scale , do_rescaling , rescaling_scale , batch_size , skip_layer_masks , skip_layer_strategy , cfg_star_rescale
1103+ transformer_state , transformer , config , mesh , latents , fractional_cords , prompt_embeds , timestep , num_inference_steps , scheduler , segment_ids , encoder_attention_segment_ids , scheduler_state , do_classifier_free_guidance , num_conds , guidance_scale , do_spatio_temporal_guidance , stg_scale , do_rescaling , rescaling_scale , batch_size , skip_layer_masks ,cfg_star_rescale
11701104):
11711105 # do_classifier_free_guidance = guidance_scale > 1.0
11721106 # for step in range(num_inference_steps):
@@ -1206,7 +1140,7 @@ def run_inference(
12061140 skip_layer_masks [i ]
12071141 if skip_layer_masks is not None
12081142 else None
1209- ), skip_layer_strategy = skip_layer_strategy )
1143+ ))
12101144 # ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('data': 4, 'fsdp': 1, 'tensor': 1, 'fsdp_transpose': 1, 'expert': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'sequence': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(('data', 'fsdp'), None, None), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 4, but it is equal to 1 (full shape: (1, 1, 128))
12111145
12121146 # # latents = self.denoising
@@ -1294,6 +1228,31 @@ def adain_filter_latent(
12941228 return result
12951229
12961230class LTXMultiScalePipeline :
1231+
1232+ @classmethod
1233+ def load_latent_upsampler (cls , config ):
1234+ spatial_upscaler_model_name_or_path = config .spatial_upscaler_model_path
1235+
1236+ if spatial_upscaler_model_name_or_path and not os .path .isfile (
1237+ spatial_upscaler_model_name_or_path
1238+ ):
1239+ spatial_upscaler_model_path = hf_hub_download (
1240+ repo_id = "Lightricks/LTX-Video" ,
1241+ filename = spatial_upscaler_model_name_or_path ,
1242+ local_dir = "/mnt/disks/diffusionproj" ,
1243+ repo_type = "model" ,
1244+ )
1245+ else :
1246+ spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
1247+ if not config .spatial_upscaler_model_path :
1248+ raise ValueError (
1249+ "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
1250+ )
1251+ latent_upsampler = LatentUpsampler .from_pretrained (spatial_upscaler_model_path )
1252+ latent_upsampler .eval ()
1253+ return latent_upsampler
1254+
1255+
12971256 def _upsample_latents (
12981257 self , latest_upsampler : LatentUpsampler , latents : torch .Tensor
12991258 ):
@@ -1309,37 +1268,29 @@ def _upsample_latents(
13091268 return upsampled_latents
13101269
13111270 def __init__ (
1312- self , video_pipeline : LTXVideoPipeline , latent_upsampler : LatentUpsampler
1271+ self , video_pipeline : LTXVideoPipeline
13131272 ):
13141273 self .video_pipeline = video_pipeline
13151274 self .vae = video_pipeline .vae
1316- self .latent_upsampler = latent_upsampler
1317-
1275+
13181276 def __call__ (
13191277 self ,
13201278 height ,
13211279 width ,
13221280 num_frames ,
1323- is_video ,
13241281 output_type ,
13251282 generator ,
13261283 config
13271284 ) -> Any :
13281285
1286+ latent_upsampler = self .load_latent_upsampler (config )
13291287 original_output_type = output_type
1330- original_width = width
1331- original_height = height
1332- x_width = int (width * config .downscale_factor )
1333- downscaled_width = x_width - (x_width % self .video_pipeline .vae_scale_factor )
1334- x_height = int (height * config .downscale_factor )
1335- downscaled_height = x_height - (x_height % self .video_pipeline .vae_scale_factor )
1336- #use original height and width here to see
13371288 output_type = 'latent'
1338- result = self .video_pipeline (height = original_height , width = original_width , num_frames = num_frames ,
1289+ result = self .video_pipeline (height = height , width = width , num_frames = num_frames ,
13391290 is_video = True , output_type = output_type , generator = generator , guidance_scale = config .first_pass ["guidance_scale" ], stg_scale = config .first_pass ["stg_scale" ], rescaling_scale = config .first_pass ["rescaling_scale" ], skip_initial_inference_steps = config .first_pass ["skip_initial_inference_steps" ], skip_final_inference_steps = config .first_pass ["skip_final_inference_steps" ],
1340- num_inference_steps = config .first_pass ["num_inference_steps" ], guidance_timesteps = config .first_pass ["guidance_timesteps" ], cfg_star_rescale = config .first_pass ["cfg_star_rescale" ], skip_layer_strategy = None , skip_block_list = config .first_pass ["skip_block_list" ])
1291+ num_inference_steps = config .first_pass ["num_inference_steps" ], guidance_timesteps = config .first_pass ["guidance_timesteps" ], cfg_star_rescale = config .first_pass ["cfg_star_rescale" ], skip_block_list = config .first_pass ["skip_block_list" ])
13411292 latents = result
1342- upsampled_latents = self ._upsample_latents (self . latent_upsampler , latents )
1293+ upsampled_latents = self ._upsample_latents (latent_upsampler , latents )
13431294 upsampled_latents = adain_filter_latent (
13441295 latents = upsampled_latents , reference_latents = latents
13451296 )
@@ -1348,20 +1299,19 @@ def __call__(
13481299
13491300 latents = upsampled_latents
13501301 output_type = original_output_type
1351- width = downscaled_width * 2
1352- height = downscaled_height * 2
1302+
13531303
1354- result = self .video_pipeline (height = original_height * 2 , width = original_width * 2 , num_frames = num_frames ,
1304+ result = self .video_pipeline (height = height * 2 , width = width * 2 , num_frames = num_frames ,
13551305 is_video = True , output_type = output_type , latents = latents , generator = generator , guidance_scale = config .second_pass ["guidance_scale" ], stg_scale = config .second_pass ["stg_scale" ], rescaling_scale = config .second_pass ["rescaling_scale" ], skip_initial_inference_steps = config .second_pass ["skip_initial_inference_steps" ], skip_final_inference_steps = config .second_pass ["skip_final_inference_steps" ],
1356- num_inference_steps = config .second_pass ["num_inference_steps" ], guidance_timesteps = config .second_pass ["guidance_timesteps" ], cfg_star_rescale = config .second_pass ["cfg_star_rescale" ], skip_layer_strategy = None , skip_block_list = config .second_pass ["skip_block_list" ])
1306+ num_inference_steps = config .second_pass ["num_inference_steps" ], guidance_timesteps = config .second_pass ["guidance_timesteps" ], cfg_star_rescale = config .second_pass ["cfg_star_rescale" ], skip_block_list = config .second_pass ["skip_block_list" ])
13571307
13581308 if original_output_type != "latent" :
13591309 num_frames = result .shape [2 ]
13601310 videos = rearrange (result , "b c f h w -> (b f) c h w" )
13611311
13621312 videos = F .interpolate (
13631313 videos ,
1364- size = (original_height , original_width ),
1314+ size = (height , width ),
13651315 mode = "bilinear" ,
13661316 align_corners = False ,
13671317 )
0 commit comments