44from maxdiffusion .pipelines .ltx_video .ltx_video_pipeline import LTXVideoPipeline
55from maxdiffusion .pipelines .ltx_video .ltx_video_pipeline import LTXMultiScalePipeline
66from maxdiffusion import pyconfig
7- from maxdiffusion .models .ltx_video .autoencoders .latent_upsampler import LatentUpsampler
8- from huggingface_hub import hf_hub_download
97import imageio
108from datetime import datetime
11-
129import os
1310import torch
1411from pathlib import Path
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Compute per-side padding that grows a source size to a target size.

    The height and width differences are each split in two: the top/left side
    gets the floor-halved amount and the bottom/right side gets whatever
    remains, so an odd difference puts the extra unit on the bottom/right.

    Args:
        source_height: Height before padding.
        source_width: Width before padding.
        target_height: Height after padding.
        target_width: Width after padding.

    Returns:
        The padding amounts ordered as ``(left, right, top, bottom)``.
    """
    extra_rows = target_height - source_height
    extra_cols = target_width - source_width

    # Leading sides take the floor half; trailing sides absorb the remainder.
    top = extra_rows // 2
    left = extra_cols // 2
    return (left, extra_cols - left, top, extra_rows - top)
3532
3633
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    """Turn a free-form prompt into a short, filename-friendly slug.

    Characters that are neither letters nor whitespace are dropped, the rest is
    lower-cased, and whole words are joined with ``-`` for as long as the
    joined string stays within ``max_len`` characters. Words are never
    truncated: the first word that does not fit ends the slug.

    Args:
        text: The prompt to convert.
        max_len: Maximum length of the returned string, separators included.

    Returns:
        The hyphen-joined slug; empty when no word fits (or ``text`` has no
        letters).
    """
    cleaned = "".join(ch.lower() for ch in text if ch.isalpha() or ch.isspace())

    kept = []
    used = 0
    for word in cleaned.split():
        # Every word after the first also costs one character for its "-"
        # separator; without counting it the joined slug could exceed max_len.
        cost = len(word) + (1 if kept else 0)
        if used + cost > max_len:
            break
        kept.append(word)
        used += cost

    return "-".join(kept)
6156
62- def create_latent_upsampler (latent_upsampler_model_path : str , device : str ):
63- latent_upsampler = LatentUpsampler .from_pretrained (latent_upsampler_model_path )
64- latent_upsampler .to (device )
65- latent_upsampler .eval ()
66- return latent_upsampler
6757
6858def get_unique_filename (
6959 base : str ,
@@ -75,78 +65,82 @@ def get_unique_filename(
7565 endswith = None ,
7666 index_range = 1000 ,
7767) -> Path :
78- base_filename = f"{ base } _{ convert_prompt_to_filename (prompt , max_len = 30 )} _{ seed } _{ resolution [0 ]} x{ resolution [1 ]} x{ resolution [2 ]} "
79- for i in range (index_range ):
80- filename = dir / \
81- f"{ base_filename } _{ i } { endswith if endswith else '' } { ext } "
82- if not os .path .exists (filename ):
83- return filename
84- raise FileExistsError (
85- f"Could not find a unique filename after { index_range } attempts."
86- )
68+ base_filename = (
69+ f"{ base } _{ convert_prompt_to_filename (prompt , max_len = 30 )} _{ seed } _{ resolution [0 ]} x{ resolution [1 ]} x{ resolution [2 ]} "
70+ )
71+ for i in range (index_range ):
72+ filename = dir / f"{ base_filename } _{ i } { endswith if endswith else '' } { ext } "
73+ if not os .path .exists (filename ):
74+ return filename
75+ raise FileExistsError (f"Could not find a unique filename after { index_range } attempts." )
8776
8877
def run(config):
    """Generate with the LTX multi-scale pipeline and save the results to disk.

    The requested height and width are rounded up to a multiple of 32 and the
    frame count to the next value of the form ``8 * k + 1`` before the pipeline
    is called. The output is then cropped back to the requested frame count and
    the padding is stripped. Each batch item is written under
    ``outputs/<YYYY-MM-DD>/`` as a PNG when it has a single frame, otherwise as
    an MP4.

    Args:
        config: Configuration object. This function reads ``height``,
            ``width``, ``num_frames``, ``frame_rate`` and ``prompt`` from it
            and also hands it to the pipeline.
    """
    padded_h = ((config.height - 1) // 32 + 1) * 32
    padded_w = ((config.width - 1) // 32 + 1) * 32
    padded_frames = ((config.num_frames - 2) // 8 + 1) * 8 + 1
    pad_left, pad_right, pad_top, pad_bottom = calculate_padding(
        config.height, config.width, padded_h, padded_w
    )

    seed = 10  # fixed seed; it is also embedded in the output filenames
    generator = torch.Generator().manual_seed(seed)
    pipeline = LTXMultiScalePipeline(
        LTXVideoPipeline.from_pretrained(config, enhance_prompt=False)
    )
    images = pipeline(
        height=padded_h,
        width=padded_w,
        num_frames=padded_frames,
        output_type="pt",
        generator=generator,
        config=config,
    )

    # Strip the padding with negative end indices. A zero pad cannot be written
    # as "-0" (that would give an empty slice), so fall back to the full extent.
    row_stop = -pad_bottom if pad_bottom else images.shape[3]
    col_stop = -pad_right if pad_right else images.shape[4]
    images = images[:, :, : config.num_frames, pad_top:row_stop, pad_left:col_stop]

    output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    output_dir.mkdir(parents=True, exist_ok=True)

    for i in range(images.shape[0]):
        # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
        frames = images[i].permute(1, 2, 3, 0).detach().float().numpy()
        # Scale to the [0, 255] range (presumably the pipeline output is in
        # [0, 1] — TODO confirm) and store as bytes.
        frames = (frames * 255).astype(np.uint8)
        fps = config.frame_rate
        height, width = frames.shape[1:3]

        # A single frame is saved as an image, anything longer as a video.
        single_frame = frames.shape[0] == 1
        base = f"image_output_{i}" if single_frame else f"video_output_{i}"
        extension = ".png" if single_frame else ".mp4"
        output_filename = get_unique_filename(
            base,
            extension,
            prompt=config.prompt,
            seed=seed,
            resolution=(height, width, config.num_frames),
            dir=output_dir,
        )

        if single_frame:
            imageio.imwrite(output_filename, frames[0])
            continue

        print(output_filename)
        # Write video
        with imageio.get_writer(output_filename, fps=fps) as writer:
            for frame in frames:
                writer.append_data(frame)
144138
145139
def main(argv: Sequence[str]) -> None:
    """Initialize the global pyconfig from ``argv`` and run generation with it."""
    pyconfig.initialize(argv)
    config = pyconfig.config
    run(config)
142+ run (pyconfig .config )
149143
150144
# Script entry point: hand control to the app runner, which invokes main().
if __name__ == "__main__":
    app.run(main)
0 commit comments