Skip to content

Commit c375471

Browse files
committed
multiscale pipeline
1 parent 972e316 commit c375471

25 files changed

Lines changed: 2426 additions & 6870 deletions

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def load_state_if_possible(
217217
return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state))
218218
else:
219219
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
220-
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
220+
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
221221

222222
def map_to_pspec(data):
223223
pspec = data.sharding.spec

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ sampler: "from_checkpoint"
2424

2525
# Generation parameters
2626
pipeline_type: multi-scale
27-
prompt: "A man walks towards a window, looks out, and then turns around. He has short, dark hair, dark skin, and is wearing a brown coat over a red and gray scarf. He walks from left to right towards a window, his gaze fixed on something outside. The camera follows him from behind at a medium distance. The room is brightly lit, with white walls and a large window covered by a white curtain. As he approaches the window, he turns his head slightly to the left, then back to the right. He then turns his entire body to the right, facing the window. The camera remains stationary as he stands in front of the window. The scene is captured in real-life footage."
27+
prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie."
2828
height: 512
2929
width: 512
3030
num_frames: 88 #344

src/maxdiffusion/generate_ltx_video.py

Lines changed: 98 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,8 @@
44
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
55
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline
66
from maxdiffusion import pyconfig
7-
from maxdiffusion.models.ltx_video.autoencoders.latent_upsampler import LatentUpsampler
8-
from huggingface_hub import hf_hub_download
97
import imageio
108
from datetime import datetime
11-
129
import os
1310
import torch
1411
from pathlib import Path
@@ -18,52 +15,45 @@ def calculate_padding(
1815
source_height: int, source_width: int, target_height: int, target_width: int
1916
) -> tuple[int, int, int, int]:
2017

21-
# Calculate total padding needed
22-
pad_height = target_height - source_height
23-
pad_width = target_width - source_width
18+
# Calculate total padding needed
19+
pad_height = target_height - source_height
20+
pad_width = target_width - source_width
2421

25-
# Calculate padding for each side
26-
pad_top = pad_height // 2
27-
pad_bottom = pad_height - pad_top # Handles odd padding
28-
pad_left = pad_width // 2
29-
pad_right = pad_width - pad_left # Handles odd padding
22+
# Calculate padding for each side
23+
pad_top = pad_height // 2
24+
pad_bottom = pad_height - pad_top # Handles odd padding
25+
pad_left = pad_width // 2
26+
pad_right = pad_width - pad_left # Handles odd padding
3027

31-
# Return padded tensor
32-
# Padding format is (left, right, top, bottom)
33-
padding = (pad_left, pad_right, pad_top, pad_bottom)
34-
return padding
28+
# Return padded tensor
29+
# Padding format is (left, right, top, bottom)
30+
padding = (pad_left, pad_right, pad_top, pad_bottom)
31+
return padding
3532

3633

3734
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
38-
# Remove non-letters and convert to lowercase
39-
clean_text = "".join(
40-
char.lower() for char in text if char.isalpha() or char.isspace()
41-
)
35+
# Remove non-letters and convert to lowercase
36+
clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace())
4237

43-
# Split into words
44-
words = clean_text.split()
38+
# Split into words
39+
words = clean_text.split()
4540

46-
# Build result string keeping track of length
47-
result = []
48-
current_length = 0
41+
# Build result string keeping track of length
42+
result = []
43+
current_length = 0
4944

50-
for word in words:
51-
# Add word length plus 1 for underscore (except for first word)
52-
new_length = current_length + len(word)
45+
for word in words:
46+
# Add word length plus 1 for underscore (except for first word)
47+
new_length = current_length + len(word)
5348

54-
if new_length <= max_len:
55-
result.append(word)
56-
current_length += len(word)
57-
else:
58-
break
49+
if new_length <= max_len:
50+
result.append(word)
51+
current_length += len(word)
52+
else:
53+
break
5954

60-
return "-".join(result)
55+
return "-".join(result)
6156

62-
def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
63-
latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
64-
latent_upsampler.to(device)
65-
latent_upsampler.eval()
66-
return latent_upsampler
6757

6858
def get_unique_filename(
6959
base: str,
@@ -75,78 +65,82 @@ def get_unique_filename(
7565
endswith=None,
7666
index_range=1000,
7767
) -> Path:
78-
base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
79-
for i in range(index_range):
80-
filename = dir / \
81-
f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
82-
if not os.path.exists(filename):
83-
return filename
84-
raise FileExistsError(
85-
f"Could not find a unique filename after {index_range} attempts."
86-
)
68+
base_filename = (
69+
f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
70+
)
71+
for i in range(index_range):
72+
filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
73+
if not os.path.exists(filename):
74+
return filename
75+
raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.")
8776

8877

8978
def run(config):
90-
height_padded = ((config.height - 1) // 32 + 1) * 32
91-
width_padded = ((config.width - 1) // 32 + 1) * 32
92-
num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1
93-
padding = calculate_padding(
94-
config.height, config.width, height_padded, width_padded)
95-
96-
seed = 10
97-
generator = torch.Generator().manual_seed(seed)
98-
pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt = False)
99-
pipeline = LTXMultiScalePipeline(pipeline)
100-
images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, output_type='pt', generator=generator, config = config)
101-
(pad_left, pad_right, pad_top, pad_bottom) = padding
102-
pad_bottom = -pad_bottom
103-
pad_right = -pad_right
104-
if pad_bottom == 0:
105-
pad_bottom = images.shape[3]
106-
if pad_right == 0:
107-
pad_right = images.shape[4]
108-
images = images[:, :, :config.num_frames,
109-
pad_top:pad_bottom, pad_left:pad_right]
110-
output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
111-
output_dir.mkdir(parents=True, exist_ok=True)
112-
for i in range(images.shape[0]):
113-
# Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
114-
video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy()
115-
# Unnormalizing images to [0, 255] range
116-
video_np = (video_np * 255).astype(np.uint8)
117-
fps = config.frame_rate
118-
height, width = video_np.shape[1:3]
119-
# In case a single image is generated
120-
if video_np.shape[0] == 1:
121-
output_filename = get_unique_filename(
122-
f"image_output_{i}",
123-
".png",
124-
prompt=config.prompt,
125-
seed=seed,
126-
resolution=(height, width, config.num_frames),
127-
dir=output_dir,
128-
)
129-
imageio.imwrite(output_filename, video_np[0])
130-
else:
131-
output_filename = get_unique_filename(
132-
f"video_output_{i}",
133-
".mp4",
134-
prompt=config.prompt,
135-
seed=seed,
136-
resolution=(height, width, config.num_frames),
137-
dir=output_dir,
138-
)
139-
print(output_filename)
140-
# Write video
141-
with imageio.get_writer(output_filename, fps=fps) as video:
142-
for frame in video_np:
143-
video.append_data(frame)
79+
height_padded = ((config.height - 1) // 32 + 1) * 32
80+
width_padded = ((config.width - 1) // 32 + 1) * 32
81+
num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1
82+
padding = calculate_padding(config.height, config.width, height_padded, width_padded)
83+
84+
seed = 10
85+
generator = torch.Generator().manual_seed(seed)
86+
pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt=False)
87+
pipeline = LTXMultiScalePipeline(pipeline)
88+
images = pipeline(
89+
height=height_padded,
90+
width=width_padded,
91+
num_frames=num_frames_padded,
92+
output_type="pt",
93+
generator=generator,
94+
config=config,
95+
)
96+
(pad_left, pad_right, pad_top, pad_bottom) = padding
97+
pad_bottom = -pad_bottom
98+
pad_right = -pad_right
99+
if pad_bottom == 0:
100+
pad_bottom = images.shape[3]
101+
if pad_right == 0:
102+
pad_right = images.shape[4]
103+
images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right]
104+
output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
105+
output_dir.mkdir(parents=True, exist_ok=True)
106+
for i in range(images.shape[0]):
107+
# Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
108+
video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy()
109+
# Unnormalizing images to [0, 255] range
110+
video_np = (video_np * 255).astype(np.uint8)
111+
fps = config.frame_rate
112+
height, width = video_np.shape[1:3]
113+
# In case a single image is generated
114+
if video_np.shape[0] == 1:
115+
output_filename = get_unique_filename(
116+
f"image_output_{i}",
117+
".png",
118+
prompt=config.prompt,
119+
seed=seed,
120+
resolution=(height, width, config.num_frames),
121+
dir=output_dir,
122+
)
123+
imageio.imwrite(output_filename, video_np[0])
124+
else:
125+
output_filename = get_unique_filename(
126+
f"video_output_{i}",
127+
".mp4",
128+
prompt=config.prompt,
129+
seed=seed,
130+
resolution=(height, width, config.num_frames),
131+
dir=output_dir,
132+
)
133+
print(output_filename)
134+
# Write video
135+
with imageio.get_writer(output_filename, fps=fps) as video:
136+
for frame in video_np:
137+
video.append_data(frame)
144138

145139

146140
def main(argv: Sequence[str]) -> None:
147-
pyconfig.initialize(argv)
148-
run(pyconfig.config)
141+
pyconfig.initialize(argv)
142+
run(pyconfig.config)
149143

150144

151145
if __name__ == "__main__":
152-
app.run(main)
146+
app.run(main)

src/maxdiffusion/max_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,4 +612,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
612612
initialize_jax_for_gpu()
613613
max_logging.log("Jax distributed system initialized on GPU!")
614614
else:
615-
jax.distributed.initialize()
615+
jax.distributed.initialize()

src/maxdiffusion/models/ltx_video/autoencoders/causal_conv3d.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

0 commit comments

Comments
 (0)