Skip to content

Commit cd16f28

Browse files
run linter
1 parent cf68754 commit cd16f28

11 files changed

Lines changed: 877 additions & 961 deletions

src/maxdiffusion/generate_wan.py

Lines changed: 116 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
import html
1718
from typing import Callable, List, Union, Sequence, Optional
1819
import time
@@ -37,21 +38,23 @@
3738
setup_initial_state,
3839
)
3940

41+
4042
def basic_clean(text):
41-
text = ftfy.fix_text(text)
42-
text = html.unescape(html.unescape(text))
43-
return text.strip()
43+
text = ftfy.fix_text(text)
44+
text = html.unescape(html.unescape(text))
45+
return text.strip()
4446

4547

4648
def whitespace_clean(text):
47-
text = re.sub(r"\s+", " ", text)
48-
text = text.strip()
49-
return text
49+
text = re.sub(r"\s+", " ", text)
50+
text = text.strip()
51+
return text
5052

5153

5254
def prompt_clean(text):
53-
text = whitespace_clean(basic_clean(text))
54-
return text
55+
text = whitespace_clean(basic_clean(text))
56+
return text
57+
5558

5659
def _get_t5_prompt_embeds(
5760
tokenizer: AutoTokenizer,
@@ -63,35 +66,36 @@ def _get_t5_prompt_embeds(
6366
dtype: Optional[torch.dtype] = None,
6467
):
6568

66-
prompt = [prompt] if isinstance(prompt, str) else prompt
67-
prompt = [prompt_clean(u) for u in prompt]
68-
batch_size = len(prompt)
69+
prompt = [prompt] if isinstance(prompt, str) else prompt
70+
prompt = [prompt_clean(u) for u in prompt]
71+
batch_size = len(prompt)
72+
73+
text_inputs = tokenizer(
74+
prompt,
75+
padding="max_length",
76+
max_length=max_sequence_length,
77+
truncation=True,
78+
add_special_tokens=True,
79+
return_attention_mask=True,
80+
return_tensors="pt",
81+
)
82+
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
83+
seq_lens = mask.gt(0).sum(dim=1).long()
84+
85+
prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
86+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
87+
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
88+
prompt_embeds = torch.stack(
89+
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
90+
)
6991

70-
text_inputs = tokenizer(
71-
prompt,
72-
padding="max_length",
73-
max_length=max_sequence_length,
74-
truncation=True,
75-
add_special_tokens=True,
76-
return_attention_mask=True,
77-
return_tensors="pt",
78-
)
79-
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
80-
seq_lens = mask.gt(0).sum(dim=1).long()
81-
82-
prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
83-
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
84-
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
85-
prompt_embeds = torch.stack(
86-
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
87-
)
92+
# duplicate text embeddings for each generation per prompt, using mps friendly method
93+
_, seq_len, _ = prompt_embeds.shape
94+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
95+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
8896

89-
# duplicate text embeddings for each generation per prompt, using mps friendly method
90-
_, seq_len, _ = prompt_embeds.shape
91-
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
92-
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
97+
return prompt_embeds
9398

94-
return prompt_embeds
9599

96100
def encode_prompt(
97101
tokenizer: AutoTokenizer,
@@ -106,77 +110,77 @@ def encode_prompt(
106110
device: Optional[torch.device] = None,
107111
dtype: Optional[torch.dtype] = None,
108112
):
109-
r"""
110-
Encodes the prompt into text encoder hidden states.
111-
112-
Args:
113-
prompt (`str` or `List[str]`, *optional*):
114-
prompt to be encoded
115-
negative_prompt (`str` or `List[str]`, *optional*):
116-
The prompt or prompts not to guide the image generation. If not defined, one has to pass
117-
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
118-
less than `1`).
119-
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
120-
Whether to use classifier free guidance or not.
121-
num_videos_per_prompt (`int`, *optional*, defaults to 1):
122-
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
123-
prompt_embeds (`torch.Tensor`, *optional*):
124-
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
125-
provided, text embeddings will be generated from `prompt` input argument.
126-
negative_prompt_embeds (`torch.Tensor`, *optional*):
127-
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
128-
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
129-
argument.
130-
device: (`torch.device`, *optional*):
131-
torch device
132-
dtype: (`torch.dtype`, *optional*):
133-
torch dtype
134-
"""
135-
136-
prompt = [prompt] if isinstance(prompt, str) else prompt
137-
if prompt is not None:
138-
batch_size = len(prompt)
139-
else:
140-
batch_size = prompt_embeds.shape[0]
141-
142-
if prompt_embeds is None:
143-
prompt_embeds = _get_t5_prompt_embeds(
144-
tokenizer=tokenizer,
145-
text_encoder=text_encoder,
146-
prompt=prompt,
147-
num_videos_per_prompt=num_videos_per_prompt,
148-
max_sequence_length=max_sequence_length,
149-
device=device,
150-
dtype=dtype,
151-
)
152-
153-
if do_classifier_free_guidance and negative_prompt_embeds is None:
154-
negative_prompt = negative_prompt or ""
155-
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
156-
157-
if prompt is not None and type(prompt) is not type(negative_prompt):
158-
raise TypeError(
159-
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
160-
f" {type(prompt)}."
161-
)
162-
elif batch_size != len(negative_prompt):
163-
raise ValueError(
164-
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
165-
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
166-
" the batch size of `prompt`."
167-
)
168-
169-
negative_prompt_embeds = _get_t5_prompt_embeds(
170-
tokenizer=tokenizer,
171-
text_encoder=text_encoder,
172-
prompt=negative_prompt,
173-
num_videos_per_prompt=num_videos_per_prompt,
174-
max_sequence_length=max_sequence_length,
175-
device=device,
176-
dtype=dtype,
177-
)
178-
179-
return prompt_embeds, negative_prompt_embeds
113+
r"""
114+
Encodes the prompt into text encoder hidden states.
115+
116+
Args:
117+
prompt (`str` or `List[str]`, *optional*):
118+
prompt to be encoded
119+
negative_prompt (`str` or `List[str]`, *optional*):
120+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
121+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
122+
less than `1`).
123+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
124+
Whether to use classifier free guidance or not.
125+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
126+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
127+
prompt_embeds (`torch.Tensor`, *optional*):
128+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
129+
provided, text embeddings will be generated from `prompt` input argument.
130+
negative_prompt_embeds (`torch.Tensor`, *optional*):
131+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
132+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
133+
argument.
134+
device: (`torch.device`, *optional*):
135+
torch device
136+
dtype: (`torch.dtype`, *optional*):
137+
torch dtype
138+
"""
139+
140+
prompt = [prompt] if isinstance(prompt, str) else prompt
141+
if prompt is not None:
142+
batch_size = len(prompt)
143+
else:
144+
batch_size = prompt_embeds.shape[0]
145+
146+
if prompt_embeds is None:
147+
prompt_embeds = _get_t5_prompt_embeds(
148+
tokenizer=tokenizer,
149+
text_encoder=text_encoder,
150+
prompt=prompt,
151+
num_videos_per_prompt=num_videos_per_prompt,
152+
max_sequence_length=max_sequence_length,
153+
device=device,
154+
dtype=dtype,
155+
)
156+
157+
if do_classifier_free_guidance and negative_prompt_embeds is None:
158+
negative_prompt = negative_prompt or ""
159+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
160+
161+
if prompt is not None and type(prompt) is not type(negative_prompt):
162+
raise TypeError(
163+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}."
164+
)
165+
elif batch_size != len(negative_prompt):
166+
raise ValueError(
167+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
168+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
169+
" the batch size of `prompt`."
170+
)
171+
172+
negative_prompt_embeds = _get_t5_prompt_embeds(
173+
tokenizer=tokenizer,
174+
text_encoder=text_encoder,
175+
prompt=negative_prompt,
176+
num_videos_per_prompt=num_videos_per_prompt,
177+
max_sequence_length=max_sequence_length,
178+
device=device,
179+
dtype=dtype,
180+
)
181+
182+
return prompt_embeds, negative_prompt_embeds
183+
180184

181185
def run(config):
182186
max_logging.log("Wan 2.1 inference script")
@@ -188,17 +192,15 @@ def run(config):
188192
global_batch_size = config.per_device_batch_size * jax.local_device_count()
189193

190194
tokenizer = AutoTokenizer.from_pretrained(
191-
config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype
195+
config.pretrained_model_name_or_path, subfolder="tokenizer", dtype=config.weights_dtype
192196
)
193197
text_encoder = UMT5EncoderModel.from_pretrained(
194-
config.pretrained_model_name_or_path, subfolder="text_encoder",
198+
config.pretrained_model_name_or_path,
199+
subfolder="text_encoder",
195200
)
196201
s0 = time.perf_counter()
197202
prompt_embeds, negative_prompt_embeds = encode_prompt(
198-
tokenizer=tokenizer,
199-
text_encoder=text_encoder,
200-
prompt=config.prompt,
201-
negative_prompt=config.negative_prompt
203+
tokenizer=tokenizer, text_encoder=text_encoder, prompt=config.prompt, negative_prompt=config.negative_prompt
202204
)
203205
max_logging.log(f"text encoding time: {(time.perf_counter() - s0)}")
204206

@@ -209,20 +211,15 @@ def run(config):
209211
# )
210212
# breakpoint()
211213

212-
pipeline, params = WanPipeline.from_pretrained(
213-
config.pretrained_model_name_or_path,
214-
vae=None,
215-
transformer=None
216-
)
217-
218-
#wan_transformer = WanModel(rngs=nnx.Rngs(config.seed))
219-
214+
pipeline, params = WanPipeline.from_pretrained(config.pretrained_model_name_or_path, vae=None, transformer=None)
220215

216+
# wan_transformer = WanModel(rngs=nnx.Rngs(config.seed))
221217

222218

223219
def main(argv: Sequence[str]) -> None:
224220
pyconfig.initialize(argv)
225221
run(pyconfig.config)
226222

223+
227224
if __name__ == "__main__":
228225
app.run(main)

src/maxdiffusion/image_processor.py

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -35,51 +35,52 @@
3535
List[torch.FloatTensor],
3636
]
3737

38+
3839
def is_valid_image(image) -> bool:
39-
r"""
40-
Checks if the input is a valid image.
40+
r"""
41+
Checks if the input is a valid image.
4142
42-
A valid image can be:
43-
- A `PIL.Image.Image`.
44-
- A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
43+
A valid image can be:
44+
- A `PIL.Image.Image`.
45+
- A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
4546
46-
Args:
47-
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
48-
The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.
47+
Args:
48+
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
49+
The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.
4950
50-
Returns:
51-
`bool`:
52-
`True` if the input is a valid image, `False` otherwise.
53-
"""
54-
return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
51+
Returns:
52+
`bool`:
53+
`True` if the input is a valid image, `False` otherwise.
54+
"""
55+
return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
5556

5657

5758
def is_valid_image_imagelist(images):
58-
r"""
59-
Checks if the input is a valid image or list of images.
59+
r"""
60+
Checks if the input is a valid image or list of images.
6061
61-
The input can be one of the following formats:
62-
- A 4D tensor or numpy array (batch of images).
63-
- A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
64-
`torch.Tensor`.
65-
- A list of valid images.
62+
The input can be one of the following formats:
63+
- A 4D tensor or numpy array (batch of images).
64+
- A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
65+
`torch.Tensor`.
66+
- A list of valid images.
6667
67-
Args:
68-
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
69-
The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
70-
images.
68+
Args:
69+
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
70+
The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
71+
images.
7172
72-
Returns:
73-
`bool`:
74-
`True` if the input is valid, `False` otherwise.
75-
"""
76-
if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
77-
return True
78-
elif is_valid_image(images):
79-
return True
80-
elif isinstance(images, list):
81-
return all(is_valid_image(image) for image in images)
82-
return False
73+
Returns:
74+
`bool`:
75+
`True` if the input is valid, `False` otherwise.
76+
"""
77+
if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
78+
return True
79+
elif is_valid_image(images):
80+
return True
81+
elif isinstance(images, list):
82+
return all(is_valid_image(image) for image in images)
83+
return False
8384

8485

8586
class VaeImageProcessor(ConfigMixin):

0 commit comments

Comments
 (0)