Skip to content

Commit fd9eb11

Browse files
committed
Fix error: attribute 'weight' already exists
1 parent b1e5b0c commit fd9eb11

2 files changed

Lines changed: 27 additions & 6 deletions

File tree

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ mesh_axes: ['data', 'fsdp', 'tensor']
6565
logical_axis_rules: [
6666
['batch', 'data'],
6767
['activation_heads', 'fsdp'],
68-
['activation_batch', ['data','fsdp']],
68+
['activation_batch', 'data'],
6969
['activation_kv', 'tensor'],
7070
['mlp','tensor'],
7171
['embed','fsdp'],

src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from transformers import (CLIPTokenizer, FlaxCLIPTextModel,
2828
T5EncoderModel, FlaxT5EncoderModel, AutoTokenizer)
2929

30+
from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler
3031
from torchax import interop
3132
from torchax import default_env
3233
import imageio
@@ -1360,6 +1361,28 @@ def adain_filter_latent(
13601361
return result
13611362

13621363
class LTXMultiScalePipeline: ##figure these methods out
1364+
1365+
@classmethod
1366+
def load_latent_upsampler(cls, config):
1367+
spatial_upscaler_model_name_or_path = config.spatial_upscaler_model_path
1368+
1369+
if spatial_upscaler_model_name_or_path and not os.path.isfile(spatial_upscaler_model_name_or_path):
1370+
spatial_upscaler_model_path = hf_hub_download(
1371+
repo_id="Lightricks/LTX-Video",
1372+
filename=spatial_upscaler_model_name_or_path,
1373+
local_dir=config.models_dir,
1374+
repo_type="model",
1375+
)
1376+
else:
1377+
spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
1378+
if not config.spatial_upscaler_model_path:
1379+
raise ValueError(
1380+
"spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
1381+
)
1382+
latent_upsampler = LatentUpsampler.from_pretrained(spatial_upscaler_model_path)
1383+
latent_upsampler.eval()
1384+
return latent_upsampler
1385+
13631386
def _upsample_latents(
13641387
self, latest_upsampler: LatentUpsampler, latents: torch.Tensor
13651388
):
@@ -1376,23 +1399,20 @@ def _upsample_latents(
13761399
return upsampled_latents
13771400

13781401
def __init__(
1379-
self, video_pipeline: LTXVideoPipeline, latent_upsampler: LatentUpsampler
1402+
self, video_pipeline: LTXVideoPipeline
13801403
):
13811404
self.video_pipeline = video_pipeline
13821405
self.vae = video_pipeline.vae
1383-
self.latent_upsampler = latent_upsampler
13841406

13851407
def __call__(
13861408
self,
13871409
height,
13881410
width,
13891411
num_frames,
1390-
is_video,
13911412
output_type,
13921413
generator,
13931414
config
13941415
) -> Any:
1395-
13961416
original_output_type = output_type
13971417
original_width = width
13981418
original_height = height
@@ -1407,7 +1427,8 @@ def __call__(
14071427
num_inference_steps = config.first_pass["num_inference_steps"], guidance_timesteps = config.first_pass["guidance_timesteps"], cfg_star_rescale = config.first_pass["cfg_star_rescale"], skip_layer_strategy = None, skip_block_list=config.first_pass["skip_block_list"])
14081428
latents = result
14091429
print("done")
1410-
upsampled_latents = self._upsample_latents(self.latent_upsampler, latents) #convert back to pytorch here
1430+
latent_upsampler = self.load_latent_upsampler(config)
1431+
upsampled_latents = self._upsample_latents(latent_upsampler, latents) #convert back to pytorch here
14111432
##maybe change this?
14121433
latents = torch.from_numpy(np.array(latents)) #.to(torch.device('cpu'))
14131434
upsampled_latents = torch.from_numpy(np.array(upsampled_latents)) #.to(torch.device('cpu'))

0 commit comments

Comments (0)