-
Notifications
You must be signed in to change notification settings - Fork 69
LTXVid text2vid pipeline #208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 50 commits
3776190
13656fb
7bed4f9
b31a97b
7e098c5
9a9f5db
d1c304d
f93c3bd
e0327e5
c369302
991a44e
b0e9bab
e18128c
1c55452
fd4af91
5e17a62
fc60b27
3243535
e873a17
d06dee3
1ea6590
aa7befd
d9a3502
a1ad421
615174f
7469c62
6de4424
18ec247
2737877
8a043f6
35a3337
546ecab
12a247f
1062c72
535c75e
f6115df
8bf24a3
0f8483e
eaa7196
7af151a
634591b
a272d08
4bcffd1
f5afa91
e805034
bb61ecb
7d4b2a9
972e316
c375471
f63a6fa
b3874f5
3e6499c
0b67a19
443243d
fefe18e
4bad196
b1e5b0c
fd9eb11
4c9be69
0577d3e
072982c
8042df0
0c48524
d4c6738
36242d2
8fc3626
774e2c4
b4bd96e
f23eeef
c18c0c6
cfe2c64
0229dd6
740d403
e34d47e
460f7db
0d7f68f
8df8dbd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| #hardware | ||
| hardware: 'tpu' | ||
| skip_jax_distributed_system: False | ||
|
|
||
| jax_cache_dir: '' | ||
| weights_dtype: 'bfloat16' | ||
| activations_dtype: 'bfloat16' | ||
|
|
||
|
|
||
| run_name: '' | ||
| output_dir: '' | ||
| save_config_to_gcs: False | ||
|
|
||
| #Checkpoints | ||
| text_encoder_model_name_or_path: "ariG23498/t5-v1-1-xxl-flax" | ||
| prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0" | ||
| prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct" | ||
| frame_rate: 30 | ||
| max_sequence_length: 512 | ||
| sampler: "from_checkpoint" | ||
|
|
||
| # Generation parameters | ||
| pipeline_type: multi-scale | ||
| prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie. " | ||
| #negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" | ||
| height: 512 | ||
| width: 512 | ||
| num_frames: 88 | ||
| flow_shift: 5.0 | ||
| downscale_factor: 0.6666666 | ||
| spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors" | ||
| prompt_enhancement_words_threshold: 120 | ||
| stg_mode: "attention_values" | ||
| decode_timestep: 0.05 | ||
| decode_noise_scale: 0.025 | ||
| seed: 10 | ||
|
|
||
|
|
||
| first_pass: | ||
| guidance_scale: [1, 1, 6, 8, 6, 1, 1] | ||
| stg_scale: [0, 0, 4, 4, 4, 2, 1] | ||
| rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1] | ||
| guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180] | ||
| skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]] | ||
| num_inference_steps: 30 | ||
| skip_final_inference_steps: 3 | ||
| skip_initial_inference_steps: 0 | ||
| cfg_star_rescale: True | ||
|
|
||
| second_pass: | ||
| guidance_scale: [1] | ||
| stg_scale: [1] | ||
| rescaling_scale: [1] | ||
| guidance_timesteps: [1.0] | ||
| skip_block_list: [27] | ||
| num_inference_steps: 30 | ||
| skip_initial_inference_steps: 17 | ||
| skip_final_inference_steps: 0 | ||
| cfg_star_rescale: True | ||
|
|
||
| #parallelism | ||
| mesh_axes: ['data', 'fsdp', 'tensor'] | ||
| logical_axis_rules: [ | ||
| ['batch', 'data'], | ||
| ['activation_heads', 'fsdp'], | ||
| ['activation_batch', 'data'], | ||
| ['activation_kv', 'tensor'], | ||
| ['mlp','tensor'], | ||
| ['embed','fsdp'], | ||
| ['heads', 'tensor'], | ||
| ['norm', 'fsdp'], | ||
| ['conv_batch', ['data','fsdp']], | ||
| ['out_channels', 'tensor'], | ||
| ['conv_out', 'fsdp'], | ||
| ['conv_in', 'fsdp'] | ||
| ] | ||
| data_sharding: [['data', 'fsdp', 'tensor']] | ||
| dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded | ||
| dcn_fsdp_parallelism: -1 | ||
| dcn_tensor_parallelism: 1 | ||
| ici_data_parallelism: 1 | ||
| ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded | ||
| ici_tensor_parallelism: 1 | ||
|
|
||
|
|
||
|
|
||
| learning_rate_schedule_steps: -1 | ||
| max_train_steps: 500 #TODO: change this | ||
| pretrained_model_name_or_path: '' | ||
| unet_checkpoint: '' | ||
| dataset_name: 'diffusers/pokemon-gpt4-captions' | ||
| train_split: 'train' | ||
| dataset_type: 'tf' | ||
| cache_latents_text_encoder_outputs: True | ||
| per_device_batch_size: 1 | ||
| compile_topology_num_slices: -1 | ||
| quantization_local_shard_count: -1 | ||
| jit_initializers: True | ||
| enable_single_replica_ckpt_restoring: False |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,174 @@ | ||
| """ | ||
| Copyright 2025 Google LLC | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| https://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| """ | ||
|
|
||
| import numpy as np | ||
| from absl import app | ||
| from typing import Sequence | ||
| from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline | ||
| from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline | ||
| from maxdiffusion import pyconfig | ||
| import imageio | ||
| from datetime import datetime | ||
| import os | ||
| import time | ||
| from pathlib import Path | ||
|
|
||
|
|
||
def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    """Compute per-side padding that centers a source frame inside a target frame.

    Returns ``(pad_left, pad_right, pad_top, pad_bottom)``. When the total
    padding along an axis is odd, the extra pixel goes to the bottom/right side.
    """

    def split_evenly(total: int) -> tuple[int, int]:
        # First side gets floor(total / 2); any remainder goes to the second side.
        first = total // 2
        return first, total - first

    pad_top, pad_bottom = split_evenly(target_height - source_height)
    pad_left, pad_right = split_evenly(target_width - source_width)
    return (pad_left, pad_right, pad_top, pad_bottom)
|
|
||
|
|
||
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    """Build a filesystem-safe slug from a prompt.

    Keeps only letters and spaces (lowercased), then joins whole words with
    ``-`` until adding another word (plus its separator) would exceed
    ``max_len``. Words are never truncated mid-word.

    Args:
        text: Arbitrary prompt text; non-letter characters are dropped.
        max_len: Maximum length of the returned slug, separators included.

    Returns:
        A lowercase, hyphen-joined slug no longer than ``max_len`` characters.
    """
    # Strip everything except letters and whitespace, lowercased.
    clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace())

    result: list[str] = []
    current_length = 0
    for word in clean_text.split():
        # Bug fix: the previous version ignored the separator when checking
        # max_len, so the joined result could exceed it. Count the '-' that
        # is inserted before every word except the first.
        separator_cost = 1 if result else 0
        if current_length + separator_cost + len(word) > max_len:
            break
        result.append(word)
        current_length += separator_cost + len(word)

    return "-".join(result)
|
|
||
|
|
||
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
    """Return the first non-existing path of the form
    ``<dir>/<base>_<prompt-slug>_<HxWxF>_<i><endswith><ext>``.

    Raises:
        FileExistsError: if every candidate index in ``range(index_range)``
            already exists on disk.
    """
    slug = convert_prompt_to_filename(prompt, max_len=30)
    stem = f"{base}_{slug}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
    suffix = endswith or ""
    for index in range(index_range):
        candidate = dir / f"{stem}_{index}{suffix}{ext}"
        if not candidate.exists():
            return candidate
    raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.")
|
|
||
|
|
||
def run(config):
    """Generate a video (or a single image) with the LTX-Video pipeline and save it.

    Pads height/width up to multiples of 32 and the frame count to the 8*k + 1
    form, runs the (optionally multi-scale) pipeline, crops the padding back
    off, and writes the result under ``outputs/<YYYY-MM-DD>/``.
    """
    # Model constraints: H and W must be multiples of 32, frame count 8*k + 1.
    height_padded = ((config.height - 1) // 32 + 1) * 32
    width_padded = ((config.width - 1) // 32 + 1) * 32
    num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1
    padding = calculate_padding(config.height, config.width, height_padded, width_padded)

    # Prompt enhancement only applies to prompts shorter than the threshold
    # (threshold <= 0 disables it entirely).
    prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold
    prompt_word_count = len(config.prompt.split())
    enhance_prompt = prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold

    pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=enhance_prompt)
    if config.pipeline_type == "multi-scale":
        pipeline = LTXMultiScalePipeline(pipeline)

    # NOTE(review): the commented-out warm-up/compile-time benchmarking call
    # that used to live here was removed per review feedback.
    s0 = time.perf_counter()
    images = pipeline(
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        is_video=True,
        output_type="pt",
        config=config,
        enhance_prompt=enhance_prompt,
        seed=config.seed,
    )
    print("generation time: ", (time.perf_counter() - s0))

    # Crop the generated tensor (B, C, F, H, W) back to the requested size.
    (pad_left, pad_right, pad_top, pad_bottom) = padding
    pad_bottom = -pad_bottom
    pad_right = -pad_right
    if pad_bottom == 0:
        # Slicing with an upper bound of -0 would yield an empty dimension;
        # substitute the full extent instead.
        pad_bottom = images.shape[3]
    if pad_right == 0:
        pad_right = images.shape[4]
    images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right]

    output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    output_dir.mkdir(parents=True, exist_ok=True)

    for i in range(images.shape[0]):
        # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
        video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy()
        # Unnormalizing images to [0, 255] range
        video_np = (video_np * 255).astype(np.uint8)
        fps = config.frame_rate
        height, width = video_np.shape[1:3]
        if video_np.shape[0] == 1:
            # In case a single image is generated, save a still instead of a video.
            output_filename = get_unique_filename(
                f"image_output_{i}",
                ".png",
                prompt=config.prompt,
                resolution=(height, width, config.num_frames),
                dir=output_dir,
            )
            imageio.imwrite(output_filename, video_np[0])
        else:
            output_filename = get_unique_filename(
                f"video_output_{i}",
                ".mp4",
                prompt=config.prompt,
                resolution=(height, width, config.num_frames),
                dir=output_dir,
            )
            print(output_filename)
            # Write video frame by frame.
            with imageio.get_writer(output_filename, fps=fps) as video:
                for frame in video_np:
                    video.append_data(frame)
|
|
||
|
|
||
def main(argv: Sequence[str]) -> None:
    # Parse CLI args / YAML config into the global pyconfig singleton, then generate.
    pyconfig.initialize(argv)
    run(pyconfig.config)


if __name__ == "__main__":
    app.run(main)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -402,7 +402,10 @@ def setup_initial_state( | |
| config.enable_single_replica_ckpt_restoring, | ||
| ) | ||
| if state: | ||
| state = state[checkpoint_item] | ||
| if checkpoint_item == " ": | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is checkpoint_item checking against |
||
| state = state | ||
| else: | ||
| state = state[checkpoint_item] | ||
| if not state: | ||
| max_logging.log(f"Could not find the item in orbax, creating state...") | ||
| init_train_state_partial = functools.partial( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| """ | ||
| Copyright 2025 Google LLC | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| https://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| """ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| # Copyright 2025 Lightricks Ltd. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://github.com/Lightricks/LTX-Video/blob/main/LICENSE | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
| # This implementation is based on the Torch version available at: | ||
| # https://github.com/Lightricks/LTX-Video/tree/main |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should this be if checkpoint_item is None?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If checkpoint_item is set to None, it cannot pass the `if checkpoint_manager and checkpoint_item:` check in max_utils.py, so I set it to a placeholder string to get around this. (Note: the code compares against a single-space string `" "`, not an empty string `""` — the sentinel value should be made consistent between the two places.)