33from typing import Sequence
44from maxdiffusion .pipelines .ltx_video .ltx_video_pipeline import LTXVideoPipeline
55from maxdiffusion import pyconfig
6- import jax .numpy as jnp
76import imageio
87from datetime import datetime
98import os
10- import json
119import torch
1210from pathlib import Path
11+
12+
1313def calculate_padding (
1414 source_height : int , source_width : int , target_height : int , target_width : int
1515) -> tuple [int , int , int , int ]:
1616
17- # Calculate total padding needed
18- pad_height = target_height - source_height
19- pad_width = target_width - source_width
20-
21- # Calculate padding for each side
22- pad_top = pad_height // 2
23- pad_bottom = pad_height - pad_top # Handles odd padding
24- pad_left = pad_width // 2
25- pad_right = pad_width - pad_left # Handles odd padding
26-
27- # Return padded tensor
28- # Padding format is (left, right, top, bottom)
29- padding = (pad_left , pad_right , pad_top , pad_bottom )
30- return padding
31-
17+ # Calculate total padding needed
18+ pad_height = target_height - source_height
19+ pad_width = target_width - source_width
20+
21+ # Calculate padding for each side
22+ pad_top = pad_height // 2
23+ pad_bottom = pad_height - pad_top # Handles odd padding
24+ pad_left = pad_width // 2
25+ pad_right = pad_width - pad_left # Handles odd padding
26+
27+ # Return padded tensor
28+ # Padding format is (left, right, top, bottom)
29+ padding = (pad_left , pad_right , pad_top , pad_bottom )
30+ return padding
31+
32+
3233def convert_prompt_to_filename (text : str , max_len : int = 20 ) -> str :
33- # Remove non-letters and convert to lowercase
34- clean_text = "" .join (
35- char .lower () for char in text if char .isalpha () or char .isspace ()
36- )
37-
38- # Split into words
39- words = clean_text .split ()
40-
41- # Build result string keeping track of length
42- result = []
43- current_length = 0
44-
45- for word in words :
46- # Add word length plus 1 for underscore (except for first word)
47- new_length = current_length + len (word )
48-
49- if new_length <= max_len :
50- result .append (word )
51- current_length += len (word )
52- else :
53- break
54-
55- return "-" .join (result )
56-
34+ # Remove non-letters and convert to lowercase
35+ clean_text = "" .join (char .lower () for char in text if char .isalpha () or char .isspace ())
36+
37+ # Split into words
38+ words = clean_text .split ()
39+
40+ # Build result string keeping track of length
41+ result = []
42+ current_length = 0
43+
44+ for word in words :
45+ # Add word length plus 1 for underscore (except for first word)
46+ new_length = current_length + len (word )
47+
48+ if new_length <= max_len :
49+ result .append (word )
50+ current_length += len (word )
51+ else :
52+ break
53+
54+ return "-" .join (result )
55+
56+
5757def get_unique_filename (
5858 base : str ,
5959 ext : str ,
@@ -64,79 +64,80 @@ def get_unique_filename(
6464 endswith = None ,
6565 index_range = 1000 ,
6666) -> Path :
67- base_filename = f"{ base } _{ convert_prompt_to_filename (prompt , max_len = 30 )} _{ seed } _{ resolution [0 ]} x{ resolution [1 ]} x{ resolution [2 ]} "
68- for i in range (index_range ):
69- filename = dir / f"{ base_filename } _{ i } { endswith if endswith else '' } { ext } "
70- if not os .path .exists (filename ):
71- return filename
72- raise FileExistsError (
73- f"Could not find a unique filename after { index_range } attempts."
74- )
67+ base_filename = (
68+ f"{ base } _{ convert_prompt_to_filename (prompt , max_len = 30 )} _{ seed } _{ resolution [0 ]} x{ resolution [1 ]} x{ resolution [2 ]} "
69+ )
70+ for i in range (index_range ):
71+ filename = dir / f"{ base_filename } _{ i } { endswith if endswith else '' } { ext } "
72+ if not os .path .exists (filename ):
73+ return filename
74+ raise FileExistsError (f"Could not find a unique filename after { index_range } attempts." )
75+
76+
7577def run (config ):
76-
78+
7779 height_padded = ((config .height - 1 ) // 32 + 1 ) * 32
7880 width_padded = ((config .width - 1 ) // 32 + 1 ) * 32
7981 num_frames_padded = ((config .num_frames - 2 ) // 8 + 1 ) * 8 + 1
8082 padding = calculate_padding (config .height , config .width , height_padded , width_padded )
8183 prompt_enhancement_words_threshold = config .prompt_enhancement_words_threshold
8284 prompt_word_count = len (config .prompt .split ())
83- enhance_prompt = (
84- prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold
85- )
86-
87- seed = 10 #change this, generator in pytorch, used in prepare_latents
85+ enhance_prompt = prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold
86+
87+ seed = 10
8888 generator = torch .Generator ().manual_seed (seed )
8989 pipeline = LTXVideoPipeline .from_pretrained (config , enhance_prompt )
90- images = pipeline (height = height_padded , width = width_padded , num_frames = num_frames_padded , is_video = True , output_type = 'pt' , generator = generator ).images
91-
90+ images = pipeline (
91+ height = height_padded ,
92+ width = width_padded ,
93+ num_frames = num_frames_padded ,
94+ is_video = True ,
95+ output_type = "pt" ,
96+ generator = generator ,
97+ ).images
98+
9299 (pad_left , pad_right , pad_top , pad_bottom ) = padding
93100 pad_bottom = - pad_bottom
94101 pad_right = - pad_right
95102 if pad_bottom == 0 :
96- pad_bottom = images .shape [3 ]
103+ pad_bottom = images .shape [3 ]
97104 if pad_right == 0 :
98- pad_right = images .shape [4 ]
99- images = images [:, :, :config .num_frames , pad_top :pad_bottom , pad_left :pad_right ]
105+ pad_right = images .shape [4 ]
106+ images = images [:, :, : config .num_frames , pad_top :pad_bottom , pad_left :pad_right ]
100107 output_dir = Path (f"outputs/{ datetime .today ().strftime ('%Y-%m-%d' )} " )
101108 output_dir .mkdir (parents = True , exist_ok = True )
102109 for i in range (images .shape [0 ]):
103- # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
104- video_np = images [i ].permute (1 , 2 , 3 , 0 ).detach ().float ().numpy ()
105- # Unnormalizing images to [0, 255] range
106- video_np = (video_np * 255 ).astype (np .uint8 )
107- fps = config .frame_rate
108- height , width = video_np .shape [1 :3 ]
109- # In case a single image is generated
110- if video_np .shape [0 ] == 1 :
111- output_filename = get_unique_filename (
112- f"image_output_{ i } " ,
113- ".png" ,
114- prompt = config .prompt ,
115- seed = seed ,
116- resolution = (height , width , config .num_frames ),
117- dir = output_dir ,
118- )
119- imageio .imwrite (output_filename , video_np [0 ])
120- else :
121- output_filename = get_unique_filename (
122- f"video_output_{ i } " ,
123- ".mp4" ,
124- prompt = config .prompt ,
125- seed = seed ,
126- resolution = (height , width , config .num_frames ),
127- dir = output_dir ,
128- )
129- print (output_filename )
130- # Write video
131- with imageio .get_writer (output_filename , fps = fps ) as video :
132- for frame in video_np :
133- video .append_data (frame )
134-
135-
136-
137-
138-
139-
110+ # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
111+ video_np = images [i ].permute (1 , 2 , 3 , 0 ).detach ().float ().numpy ()
112+ # Unnormalizing images to [0, 255] range
113+ video_np = (video_np * 255 ).astype (np .uint8 )
114+ fps = config .frame_rate
115+ height , width = video_np .shape [1 :3 ]
116+ # In case a single image is generated
117+ if video_np .shape [0 ] == 1 :
118+ output_filename = get_unique_filename (
119+ f"image_output_{ i } " ,
120+ ".png" ,
121+ prompt = config .prompt ,
122+ seed = seed ,
123+ resolution = (height , width , config .num_frames ),
124+ dir = output_dir ,
125+ )
126+ imageio .imwrite (output_filename , video_np [0 ])
127+ else :
128+ output_filename = get_unique_filename (
129+ f"video_output_{ i } " ,
130+ ".mp4" ,
131+ prompt = config .prompt ,
132+ seed = seed ,
133+ resolution = (height , width , config .num_frames ),
134+ dir = output_dir ,
135+ )
136+ print (output_filename )
137+ # Write video
138+ with imageio .get_writer (output_filename , fps = fps ) as video :
139+ for frame in video_np :
140+ video .append_data (frame )
140141
141142
142143def main (argv : Sequence [str ]) -> None :
@@ -145,4 +146,4 @@ def main(argv: Sequence[str]) -> None:
145146
146147
147148if __name__ == "__main__" :
148- app .run (main )
149+ app .run (main )
0 commit comments