Skip to content

Commit 120ceb3

Browse files
wip - vae
1 parent ae7a538 commit 120ceb3

11 files changed

Lines changed: 737 additions & 5 deletions

File tree

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ MaxDiffusion supports
5353
- [Training](#training)
5454
- [Dreambooth](#dreambooth)
5555
- [Inference](#inference)
56+
- [Wan 2.1](#wan)
5657
- [Flux](#flux)
5758
- [Fused Attention for GPU:](#fused-attention-for-gpu)
5859
- [Hyper SDXL LoRA](#hyper-sdxl-lora)
@@ -171,6 +172,13 @@ To generate images, run the following command:
171172
```bash
172173
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
173174
```
175+
176+
## Wan
177+
178+
```bash
179+
python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_t2v.yml run_name="wan-test" output_dir="gs://jfacevedo-maxdiffusion" jax_cache_dir="/tmp/"
180+
```
181+
174182
## Flux
175183

176184
First make sure you have permissions to access the Flux repos in Huggingface.

src/maxdiffusion/configs/base_wan_t2v.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@ gcs_metrics: False
2323
save_config_to_gcs: False
2424
log_period: 100
2525

26-
pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
27-
clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
28-
t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'
26+
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
2927

3028
# Flux params
3129
flux_name: "flux-dev"

src/maxdiffusion/generate_wan.py

Lines changed: 189 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,205 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16-
17-
from typing import Callable, List, Union, Sequence
16+
import html
17+
from typing import Callable, List, Union, Sequence, Optional
18+
import time
19+
import torch
20+
import ftfy
21+
import regex as re
22+
import jax
23+
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
1824
from flax import nnx
1925
from absl import app
26+
from transformers import AutoTokenizer, UMT5EncoderModel
2027
from maxdiffusion import pyconfig, max_logging
2128
from maxdiffusion.models.wan.transformers.transformer_flux_wan_nnx import WanModel
29+
from maxdiffusion.pipelines.wan.pipeline_wan import WanPipeline
30+
31+
from maxdiffusion.max_utils import (
32+
device_put_replicated,
33+
get_memory_allocations,
34+
create_device_mesh,
35+
get_flash_block_sizes,
36+
get_precision,
37+
setup_initial_state,
38+
)
39+
40+
def basic_clean(text):
    """Repair broken text encoding and HTML entities, then trim the ends.

    Args:
        text: Raw prompt string.

    Returns:
        The text after `ftfy.fix_text`, two rounds of `html.unescape`, and
        stripping of leading/trailing whitespace.
    """
    repaired = ftfy.fix_text(text)
    # Unescape twice so doubly-escaped entities (e.g. "&amp;amp;") are fully
    # decoded.
    unescaped_once = html.unescape(repaired)
    unescaped_twice = html.unescape(unescaped_once)
    return unescaped_twice.strip()
44+
45+
46+
def whitespace_clean(text):
    """Collapse every run of whitespace into one space and trim the ends.

    Args:
        text: Input string.

    Returns:
        The string with each whitespace run replaced by a single space and
        no leading or trailing whitespace.
    """
    return re.sub(r"\s+", " ", text).strip()
50+
51+
52+
def prompt_clean(text):
    """Normalize a prompt: run `basic_clean`, then `whitespace_clean`.

    Args:
        text: Raw prompt string.

    Returns:
        The cleaned prompt string.
    """
    cleaned = basic_clean(text)
    return whitespace_clean(cleaned)
55+
56+
def _get_t5_prompt_embeds(
    tokenizer: AutoTokenizer,
    text_encoder: UMT5EncoderModel,
    prompt: Union[str, List[str]] = None,
    num_videos_per_prompt: int = 1,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Tokenize `prompt` and encode it with the text encoder.

    Args:
        tokenizer: Tokenizer called with the cleaned prompts; must return
            `input_ids` and `attention_mask` as torch tensors.
        text_encoder: Encoder called with `(input_ids, attention_mask)`; its
            `last_hidden_state` is used as the embedding.
        prompt: A single prompt or a list of prompts. Although the default is
            `None`, a value is required: `None` fails when iterated.
        num_videos_per_prompt: How many times each prompt's embedding is
            duplicated in the output batch.
        max_sequence_length: Tokenizer padding/truncation length, and the
            sequence length of the returned embeddings.
        device: Device the token tensors and embeddings are moved to.
        dtype: Dtype the embeddings are cast to.

    Returns:
        A tensor of shape
        `(batch_size * num_videos_per_prompt, max_sequence_length, hidden)`
        where positions past each prompt's real token count are exact zeros.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    prompt = [prompt_clean(u) for u in prompt]
    batch_size = len(prompt)

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
    # Number of non-padding tokens in each prompt.
    seq_lens = mask.gt(0).sum(dim=1).long()

    # Inference only: skip building an autograd graph through the encoder so
    # activations are not retained in memory.
    with torch.no_grad():
        prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # Keep only the first `seq_len` positions of each prompt and pad back to
    # `max_sequence_length` with zeros, so padded positions carry no signal.
    # NOTE(review): this assumes the tokenizer pads on the right - confirm.
    prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds],
        dim=0,
    )

    # Duplicate text embeddings for each generation per prompt, using an
    # mps-friendly method (repeat along the sequence axis, then reshape).
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds
95+
96+
def encode_prompt(
    tokenizer: AutoTokenizer,
    text_encoder: UMT5EncoderModel,
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]] = None,
    do_classifier_free_guidance: bool = True,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    max_sequence_length: int = 226,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    r"""
    Encodes the prompt (and optionally the negative prompt) into text encoder hidden states.

    Args:
        tokenizer (`AutoTokenizer`):
            Tokenizer forwarded to `_get_t5_prompt_embeds`.
        text_encoder (`UMT5EncoderModel`):
            Text encoder forwarded to `_get_t5_prompt_embeds`.
        prompt (`str` or `List[str]`, *optional*):
            Prompt or prompts to be encoded. May be `None` when `prompt_embeds` is given.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the video generation. A single string is repeated for every
            prompt in the batch; `None` is treated as the empty string. Only used when
            `do_classifier_free_guidance` is `True` and `negative_prompt_embeds` is not given.
        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
            Whether to use classifier free guidance or not.
        num_videos_per_prompt (`int`, *optional*, defaults to 1):
            Number of videos that should be generated per prompt.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from the `prompt` input argument.
        negative_prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated negative text embeddings. If not provided (and guidance is enabled),
            negative_prompt_embeds will be generated from the `negative_prompt` input argument.
        max_sequence_length (`int`, *optional*, defaults to 226):
            Padding/truncation length used when tokenizing.
        device (`torch.device`, *optional*):
            torch device to place the resulting embeddings on.
        dtype (`torch.dtype`, *optional*):
            torch dtype of the resulting embeddings.

    Returns:
        `Tuple[torch.Tensor, Optional[torch.Tensor]]`: `(prompt_embeds, negative_prompt_embeds)`. The second
        element stays `None` when guidance is disabled and no `negative_prompt_embeds` was passed in.

    Raises:
        ValueError: If neither `prompt` nor `prompt_embeds` is given, or if `negative_prompt` does not match
            the batch size of `prompt`.
        TypeError: If `negative_prompt` is not the same type as `prompt` after normalization.
    """
    # Without either input there is nothing to encode and no batch size to
    # derive; fail with a clear message instead of an AttributeError on None.
    if prompt is None and prompt_embeds is None:
        raise ValueError("Provide either `prompt` or `prompt_embeds`.")

    prompt = [prompt] if isinstance(prompt, str) else prompt
    if prompt is not None:
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        prompt_embeds = _get_t5_prompt_embeds(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            prompt=prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    if do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        # Broadcast a single negative prompt across the whole batch.
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

        # `prompt` was normalized to a list above, so this compares container types.
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )

        negative_prompt_embeds = _get_t5_prompt_embeds(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            prompt=negative_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

    return prompt_embeds, negative_prompt_embeds
22180

23181
def run(config):
    """Work-in-progress Wan 2.1 inference entry point.

    Creates the device mesh, loads the tokenizer and text encoder from
    `config.pretrained_model_name_or_path`, encodes `config.prompt` and
    `config.negative_prompt`, logs how long encoding took, and instantiates
    the Wan transformer.

    Args:
        config: Parsed run configuration (the `pyconfig.config` object).
    """
    max_logging.log("Wan 2.1 inference script")

    # NOTE: `rng`, `mesh` and `global_batch_size` are set up here but not yet
    # consumed anywhere in this function.
    rng = jax.random.key(config.seed)
    devices_array = create_device_mesh(config)
    mesh = Mesh(devices_array, config.mesh_axes)

    global_batch_size = config.per_device_batch_size * jax.local_device_count()

    model_id = config.pretrained_model_name_or_path
    tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", dtype=config.weights_dtype)
    text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")

    encode_start = time.perf_counter()
    prompt_embeds, negative_prompt_embeds = encode_prompt(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        prompt=config.prompt,
        negative_prompt=config.negative_prompt,
    )
    elapsed = time.perf_counter() - encode_start
    max_logging.log(f"text encoding time: {elapsed}")

    # TODO: enable pipeline loading once WanPipeline is ready:
    # pipeline, params = WanPipeline.from_pretrained(
    #     config.pretrained_model_name_or_path,
    #     # vae=None,
    #     # transformer=None
    # )

    wan_transformer = WanModel(rngs=nnx.Rngs(config.seed))
27213

214+
28215
def main(argv: Sequence[str]) -> None:
    """Initialize the global config from `argv`, then run Wan inference.

    Args:
        argv: Command-line arguments passed to `pyconfig.initialize`.
    """
    pyconfig.initialize(argv)
    # Read the config only after `initialize` has been called.
    config = pyconfig.config
    run(config)

src/maxdiffusion/image_processor.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,52 @@
3535
List[torch.FloatTensor],
3636
]
3737

38+
def is_valid_image(image) -> bool:
    r"""
    Tell whether `image` is a single valid image.

    Accepted inputs:
    - a `PIL.Image.Image`;
    - an `np.ndarray` or `torch.Tensor` with 2 or 3 dimensions (grayscale or color image).

    Args:
        image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
            The candidate image.

    Returns:
        `bool`:
            `True` if the input is a valid image, `False` otherwise.
    """
    # PIL images are accepted without any rank check.
    if isinstance(image, PIL.Image.Image):
        return True
    # Anything that is neither an array nor a tensor is rejected.
    if not isinstance(image, (np.ndarray, torch.Tensor)):
        return False
    return image.ndim in (2, 3)
55+
56+
57+
def is_valid_image_imagelist(images):
    r"""
    Tell whether `images` is a valid image, batch of images, or list of images.

    Accepted inputs:
    - a 4D `np.ndarray` or `torch.Tensor` (batch of images);
    - a single valid image as defined by `is_valid_image`;
    - a `list` whose elements are all valid single images (an empty list passes).

    Args:
        images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
            The image(s) to check.

    Returns:
        `bool`:
            `True` if the input is valid, `False` otherwise.
    """
    is_batch = isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4
    if is_batch or is_valid_image(images):
        return True
    if isinstance(images, list):
        return all(map(is_valid_image, images))
    return False
83+
3884

3985
class VaeImageProcessor(ConfigMixin):
4086
"""

0 commit comments

Comments
 (0)