2727from transformers import (CLIPTokenizer , FlaxCLIPTextModel ,
2828 T5EncoderModel , FlaxT5EncoderModel , AutoTokenizer )
2929
30+ from maxdiffusion .models .ltx_video .autoencoders .latent_upsampler import LatentUpsampler
3031from torchax import interop
3132from torchax import default_env
3233import imageio
@@ -1360,6 +1361,28 @@ def adain_filter_latent(
13601361 return result
13611362
13621363class LTXMultiScalePipeline : ##figure these methods out
1364+
1365+ @classmethod
1366+ def load_latent_upsampler (cls , config ):
1367+ spatial_upscaler_model_name_or_path = config .spatial_upscaler_model_path
1368+
1369+ if spatial_upscaler_model_name_or_path and not os .path .isfile (spatial_upscaler_model_name_or_path ):
1370+ spatial_upscaler_model_path = hf_hub_download (
1371+ repo_id = "Lightricks/LTX-Video" ,
1372+ filename = spatial_upscaler_model_name_or_path ,
1373+ local_dir = config .models_dir ,
1374+ repo_type = "model" ,
1375+ )
1376+ else :
1377+ spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
1378+ if not config .spatial_upscaler_model_path :
1379+ raise ValueError (
1380+ "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
1381+ )
1382+ latent_upsampler = LatentUpsampler .from_pretrained (spatial_upscaler_model_path )
1383+ latent_upsampler .eval ()
1384+ return latent_upsampler
1385+
13631386 def _upsample_latents (
13641387 self , latest_upsampler : LatentUpsampler , latents : torch .Tensor
13651388 ):
@@ -1376,23 +1399,20 @@ def _upsample_latents(
13761399 return upsampled_latents
13771400
13781401 def __init__ (
1379- self , video_pipeline : LTXVideoPipeline , latent_upsampler : LatentUpsampler
1402+ self , video_pipeline : LTXVideoPipeline
13801403 ):
13811404 self .video_pipeline = video_pipeline
13821405 self .vae = video_pipeline .vae
1383- self .latent_upsampler = latent_upsampler
13841406
13851407 def __call__ (
13861408 self ,
13871409 height ,
13881410 width ,
13891411 num_frames ,
1390- is_video ,
13911412 output_type ,
13921413 generator ,
13931414 config
13941415 ) -> Any :
1395-
13961416 original_output_type = output_type
13971417 original_width = width
13981418 original_height = height
@@ -1407,7 +1427,8 @@ def __call__(
14071427 num_inference_steps = config .first_pass ["num_inference_steps" ], guidance_timesteps = config .first_pass ["guidance_timesteps" ], cfg_star_rescale = config .first_pass ["cfg_star_rescale" ], skip_layer_strategy = None , skip_block_list = config .first_pass ["skip_block_list" ])
14081428 latents = result
14091429 print ("done" )
1410- upsampled_latents = self ._upsample_latents (self .latent_upsampler , latents ) #convert back to pytorch here
1430+ latent_upsampler = self .load_latent_upsampler (config )
1431+ upsampled_latents = self ._upsample_latents (latent_upsampler , latents ) #convert back to pytorch here
14111432 ##maybe change this?
14121433 latents = torch .from_numpy (np .array (latents )) #.to(torch.device('cpu'))
14131434 upsampled_latents = torch .from_numpy (np .array (upsampled_latents )) #.to(torch.device('cpu'))
0 commit comments