Skip to content

Commit 169a847

Browse files
committed
ruff check fixed
1 parent 41c342f commit 169a847

36 files changed

Lines changed: 5878 additions & 8318 deletions

src/maxdiffusion/checkpointing/checkpointing_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -217,7 +217,7 @@ def load_state_if_possible(
217217
return checkpoint_manager.restore(latest_step, args=ocp.args.StandardRestore(abstract_unboxed_pre_state))
218218
else:
219219
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)}
220-
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
220+
return checkpoint_manager.restore(latest_step, args=orbax.checkpoint.args.Composite(**item))
221221

222222
def map_to_pspec(data):
223223
pspec = data.sharding.spec

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ jax_cache_dir: ''
66
weights_dtype: 'bfloat16'
77
activations_dtype: 'bfloat16'
88

9-
109
run_name: ''
1110
output_dir: 'ltx-video-output'
1211
save_config_to_gcs: False
@@ -21,9 +20,9 @@ frame_rate: 30
2120

2221

2322
# Generation parameters
23+
ckpt_path: "/mnt/disks/diffusionproj"
2424
prompt: "A man in a dimly lit room talks on a vintage telephone, hangs up, and looks down with a sad expression. He holds the black rotary phone to his right ear with his right hand, his left hand holding a rocks glass with amber liquid. He wears a brown suit jacket over a white shirt, and a gold ring on his left ring finger. His short hair is neatly combed, and he has light skin with visible wrinkles around his eyes. The camera remains stationary, focused on his face and upper body. The room is dark, lit only by a warm light source off-screen to the left, casting shadows on the wall behind him. The scene appears to be from a movie."
2525
#negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
26-
#do_classifier_free_guidance: True
2726
height: 512
2827
width: 512
2928
num_frames: 88
@@ -63,7 +62,7 @@ ici_sequence_parallelism: 1
6362

6463

6564
learning_rate_schedule_steps: -1
66-
max_train_steps: 500 #TODO: change this
65+
max_train_steps: 500
6766
pretrained_model_name_or_path: ''
6867
unet_checkpoint: ''
6968
dataset_name: 'diffusers/pokemon-gpt4-captions'

src/maxdiffusion/generate_ltx_video.py

Lines changed: 99 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -3,57 +3,57 @@
33
from typing import Sequence
44
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXVideoPipeline
55
from maxdiffusion import pyconfig
6-
import jax.numpy as jnp
76
import imageio
87
from datetime import datetime
98
import os
10-
import json
119
import torch
1210
from pathlib import Path
11+
12+
1313
def calculate_padding(
1414
source_height: int, source_width: int, target_height: int, target_width: int
1515
) -> tuple[int, int, int, int]:
1616

17-
# Calculate total padding needed
18-
pad_height = target_height - source_height
19-
pad_width = target_width - source_width
20-
21-
# Calculate padding for each side
22-
pad_top = pad_height // 2
23-
pad_bottom = pad_height - pad_top # Handles odd padding
24-
pad_left = pad_width // 2
25-
pad_right = pad_width - pad_left # Handles odd padding
26-
27-
# Return the padding amounts; no tensor is padded in this function
28-
# Padding format is (left, right, top, bottom)
29-
padding = (pad_left, pad_right, pad_top, pad_bottom)
30-
return padding
31-
17+
# Calculate total padding needed
18+
pad_height = target_height - source_height
19+
pad_width = target_width - source_width
20+
21+
# Calculate padding for each side
22+
pad_top = pad_height // 2
23+
pad_bottom = pad_height - pad_top # Handles odd padding
24+
pad_left = pad_width // 2
25+
pad_right = pad_width - pad_left # Handles odd padding
26+
27+
# Return the padding amounts; no tensor is padded in this function
28+
# Padding format is (left, right, top, bottom)
29+
padding = (pad_left, pad_right, pad_top, pad_bottom)
30+
return padding
31+
32+
3233
def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
33-
# Keep only letters and whitespace, converted to lowercase
34-
clean_text = "".join(
35-
char.lower() for char in text if char.isalpha() or char.isspace()
36-
)
37-
38-
# Split into words
39-
words = clean_text.split()
40-
41-
# Build result string keeping track of length
42-
result = []
43-
current_length = 0
44-
45-
for word in words:
46-
# Count only the word's letters toward max_len; the "-" separators added by the final join are not counted
47-
new_length = current_length + len(word)
48-
49-
if new_length <= max_len:
50-
result.append(word)
51-
current_length += len(word)
52-
else:
53-
break
54-
55-
return "-".join(result)
56-
34+
# Keep only letters and whitespace, converted to lowercase
35+
clean_text = "".join(char.lower() for char in text if char.isalpha() or char.isspace())
36+
37+
# Split into words
38+
words = clean_text.split()
39+
40+
# Build result string keeping track of length
41+
result = []
42+
current_length = 0
43+
44+
for word in words:
45+
# Count only the word's letters toward max_len; the "-" separators added by the final join are not counted
46+
new_length = current_length + len(word)
47+
48+
if new_length <= max_len:
49+
result.append(word)
50+
current_length += len(word)
51+
else:
52+
break
53+
54+
return "-".join(result)
55+
56+
5757
def get_unique_filename(
5858
base: str,
5959
ext: str,
@@ -64,79 +64,80 @@ def get_unique_filename(
6464
endswith=None,
6565
index_range=1000,
6666
) -> Path:
67-
base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
68-
for i in range(index_range):
69-
filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
70-
if not os.path.exists(filename):
71-
return filename
72-
raise FileExistsError(
73-
f"Could not find a unique filename after {index_range} attempts."
74-
)
67+
base_filename = (
68+
f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
69+
)
70+
for i in range(index_range):
71+
filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
72+
if not os.path.exists(filename):
73+
return filename
74+
raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.")
75+
76+
7577
def run(config):
76-
78+
7779
height_padded = ((config.height - 1) // 32 + 1) * 32
7880
width_padded = ((config.width - 1) // 32 + 1) * 32
7981
num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1
8082
padding = calculate_padding(config.height, config.width, height_padded, width_padded)
8183
prompt_enhancement_words_threshold = config.prompt_enhancement_words_threshold
8284
prompt_word_count = len(config.prompt.split())
83-
enhance_prompt = (
84-
prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold
85-
)
86-
87-
seed = 10 #change this, generator in pytorch, used in prepare_latents
85+
enhance_prompt = prompt_enhancement_words_threshold > 0 and prompt_word_count < prompt_enhancement_words_threshold
86+
87+
seed = 10
8888
generator = torch.Generator().manual_seed(seed)
8989
pipeline = LTXVideoPipeline.from_pretrained(config, enhance_prompt)
90-
images = pipeline(height=height_padded, width=width_padded, num_frames=num_frames_padded, is_video=True, output_type='pt', generator = generator).images
91-
90+
images = pipeline(
91+
height=height_padded,
92+
width=width_padded,
93+
num_frames=num_frames_padded,
94+
is_video=True,
95+
output_type="pt",
96+
generator=generator,
97+
).images
98+
9299
(pad_left, pad_right, pad_top, pad_bottom) = padding
93100
pad_bottom = -pad_bottom
94101
pad_right = -pad_right
95102
if pad_bottom == 0:
96-
pad_bottom = images.shape[3]
103+
pad_bottom = images.shape[3]
97104
if pad_right == 0:
98-
pad_right = images.shape[4]
99-
images = images[:, :, :config.num_frames, pad_top:pad_bottom, pad_left:pad_right]
105+
pad_right = images.shape[4]
106+
images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right]
100107
output_dir = Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
101108
output_dir.mkdir(parents=True, exist_ok=True)
102109
for i in range(images.shape[0]):
103-
# Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
104-
video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy()
105-
# Unnormalizing images to [0, 255] range
106-
video_np = (video_np * 255).astype(np.uint8)
107-
fps = config.frame_rate
108-
height, width = video_np.shape[1:3]
109-
# In case a single image is generated
110-
if video_np.shape[0] == 1:
111-
output_filename = get_unique_filename(
112-
f"image_output_{i}",
113-
".png",
114-
prompt=config.prompt,
115-
seed=seed,
116-
resolution=(height, width, config.num_frames),
117-
dir=output_dir,
118-
)
119-
imageio.imwrite(output_filename, video_np[0])
120-
else:
121-
output_filename = get_unique_filename(
122-
f"video_output_{i}",
123-
".mp4",
124-
prompt=config.prompt,
125-
seed=seed,
126-
resolution=(height, width, config.num_frames),
127-
dir=output_dir,
128-
)
129-
print(output_filename)
130-
# Write video
131-
with imageio.get_writer(output_filename, fps=fps) as video:
132-
for frame in video_np:
133-
video.append_data(frame)
134-
135-
136-
137-
138-
139-
110+
# Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
111+
video_np = images[i].permute(1, 2, 3, 0).detach().float().numpy()
112+
# Unnormalizing images to [0, 255] range
113+
video_np = (video_np * 255).astype(np.uint8)
114+
fps = config.frame_rate
115+
height, width = video_np.shape[1:3]
116+
# In case a single image is generated
117+
if video_np.shape[0] == 1:
118+
output_filename = get_unique_filename(
119+
f"image_output_{i}",
120+
".png",
121+
prompt=config.prompt,
122+
seed=seed,
123+
resolution=(height, width, config.num_frames),
124+
dir=output_dir,
125+
)
126+
imageio.imwrite(output_filename, video_np[0])
127+
else:
128+
output_filename = get_unique_filename(
129+
f"video_output_{i}",
130+
".mp4",
131+
prompt=config.prompt,
132+
seed=seed,
133+
resolution=(height, width, config.num_frames),
134+
dir=output_dir,
135+
)
136+
print(output_filename)
137+
# Write video
138+
with imageio.get_writer(output_filename, fps=fps) as video:
139+
for frame in video_np:
140+
video.append_data(frame)
140141

141142

142143
def main(argv: Sequence[str]) -> None:
@@ -145,4 +146,4 @@ def main(argv: Sequence[str]) -> None:
145146

146147

147148
if __name__ == "__main__":
148-
app.run(main)
149+
app.run(main)

src/maxdiffusion/max_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ
251251

252252
return parallelism_vals
253253

254+
254255
def create_device_mesh(config, devices=None):
255256
"""Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas"""
256257
if devices is None:
@@ -269,6 +270,7 @@ def create_device_mesh(config, devices=None):
269270

270271
return mesh
271272

273+
272274
def unbox_logicallypartioned_trainstate(boxed_train_state: train_state.TrainState):
273275
"""Unboxes the flax.LogicallyPartitioned pieces in a train state.
274276
@@ -590,4 +592,4 @@ def maybe_initialize_jax_distributed_system(raw_keys):
590592
initialize_jax_for_gpu()
591593
max_logging.log("Jax distributed system initialized on GPU!")
592594
else:
593-
jax.distributed.initialize()
595+
jax.distributed.initialize()

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1188,4 +1188,4 @@ def setup(self):
11881188
def __call__(self, hidden_states, deterministic=True):
11891189
hidden_states = self.proj(hidden_states)
11901190
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
1191-
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
1191+
return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)

0 commit comments

Comments
 (0)