Skip to content

Commit 9276f26

Browse files
wan pipeline wip
1 parent ae7a538 commit 9276f26

10 files changed

Lines changed: 422 additions & 4 deletions

File tree

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ MaxDiffusion supports
5353
- [Training](#training)
5454
- [Dreambooth](#dreambooth)
5555
- [Inference](#inference)
56+
- [Wan 2.1](#wan)
5657
- [Flux](#flux)
5758
- [Fused Attention for GPU:](#fused-attention-for-gpu)
5859
- [Hyper SDXL LoRA](#hyper-sdxl-lora)
@@ -171,6 +172,13 @@ To generate images, run the following command:
171172
```bash
172173
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
173174
```
175+
176+
## Wan
177+
178+
```bash
179+
python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_t2v.yml run_name="wan-test" output_dir="gs://your-output-bucket" jax_cache_dir="/tmp/"
180+
```
181+
174182
## Flux
175183

176184
First make sure you have permissions to access the Flux repos in Huggingface.

src/maxdiffusion/configs/base_wan_t2v.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@ gcs_metrics: False
2323
save_config_to_gcs: False
2424
log_period: 100
2525

26-
pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
27-
clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
28-
t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'
26+
pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'
2927

3028
# Flux params
3129
flux_name: "flux-dev"

src/maxdiffusion/generate_wan.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,20 @@
1919
from absl import app
2020
from maxdiffusion import pyconfig, max_logging
2121
from maxdiffusion.models.wan.transformers.transformer_flux_wan_nnx import WanModel
22+
from maxdiffusion.pipelines.wan.pipeline_wan import WanPipeline
2223

2324
def run(config):
  """Run the Wan 2.1 text-to-video inference entry point.

  Args:
    config: pyconfig configuration object; `pretrained_model_name_or_path`
      selects the Diffusers checkpoint to load.
  """
  max_logging.log("Wan 2.1 inference script")

  # vae/transformer are passed as None because their Flax ports are still
  # work in progress; from_pretrained loads the remaining components.
  pipeline, params = WanPipeline.from_pretrained(
      config.pretrained_model_name_or_path,
      vae=None,
      transformer=None,
  )
  return pipeline, params
35+
2736

2837
def main(argv: Sequence[str]) -> None:
2938
pyconfig.initialize(argv)

src/maxdiffusion/image_processor.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,52 @@
3535
List[torch.FloatTensor],
3636
]
3737

38+
def is_valid_image(image) -> bool:
  r"""
  Report whether `image` is a single valid image.

  A valid image is either a `PIL.Image.Image`, or a 2D/3D `np.ndarray` /
  `torch.Tensor` (grayscale or color).

  Args:
    image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
      The candidate image.

  Returns:
    `bool`: `True` when the input is a valid image, `False` otherwise.
  """
  if isinstance(image, PIL.Image.Image):
    return True
  if isinstance(image, (np.ndarray, torch.Tensor)):
    return image.ndim in (2, 3)
  return False
55+
56+
57+
def is_valid_image_imagelist(images):
  r"""
  Report whether `images` is a valid image, batch of images, or list of images.

  Accepted forms:
    - a 4D `np.ndarray` / `torch.Tensor` (batch of images);
    - a single valid image (see `is_valid_image`);
    - a list whose every element is a valid image.

  Args:
    images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
      The candidate image(s).

  Returns:
    `bool`: `True` when the input is valid, `False` otherwise.
  """
  if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
    return True
  if isinstance(images, list):
    return all(is_valid_image(img) for img in images)
  return is_valid_image(images)
83+
3884

3985
class VaeImageProcessor(ConfigMixin):
4086
"""
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from typing import Tuple, List
18+
from flax import nnx
19+
from ...configuration_utils import ConfigMixin, flax_register_to_config
20+
from ..modeling_flax_utils import FlaxModelMixin
21+
22+
class WanEncoder3d(nnx.Module):
  # TODO(wip): 3D encoder for the Wan VAE; not yet implemented. AutoencoderKLWan
  # below instantiates it with positional args, so this stub cannot run yet.
  pass

class WanCausalConv3d(nnx.Module):
  # TODO(wip): causal 3D convolution used by the Wan VAE; not yet implemented.
  pass

class WanDecoder3d(nnx.Module):
  # TODO(wip): 3D decoder for the Wan VAE; not yet implemented.
  pass
30+
31+
class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin):
  """Flax/NNX port of the Wan 2.1 3D causal VAE (work in progress).

  The encoder/decoder/conv submodules are stubs for now; this class only wires
  up the configuration. Defaults mirror the Wan 2.1 Diffusers checkpoint.
  """

  def __init__(
      self,
      base_dim: int = 96,
      z_dim: int = 16,
      # Immutable tuple defaults instead of lists: a mutable default is shared
      # across every call and can be silently mutated. Callers may still pass
      # lists.
      dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
      num_res_blocks: int = 2,
      attn_scales: Tuple[float, ...] = (),
      temporal_downsample: Tuple[bool, ...] = (False, True, True),
      dropout: float = 0.0,
      # Per-channel latent statistics of the pretrained checkpoint, used to
      # normalize/denormalize latents.
      latents_mean: Tuple[float, ...] = (
          -0.7571, -0.7089, -0.9113, 0.1075,
          -0.1745, 0.9653, -0.1517, 1.5508,
          0.4134, -0.0715, 0.5517, -0.3632,
          -0.1922, -0.9497, 0.2503, -0.2921,
      ),
      latents_std: Tuple[float, ...] = (
          2.8184, 1.4541, 2.3275, 2.6558,
          1.2196, 1.7708, 2.6052, 2.0743,
          3.2687, 2.1526, 2.8652, 1.5579,
          1.6382, 1.1253, 2.8251, 1.9160,
      ),
  ):
    self.z_dim = z_dim
    self.temporal_downsample = temporal_downsample
    # The decoder upsamples in the reverse order of the encoder's downsampling.
    self.temporal_upsample = temporal_downsample[::-1]
    # Previously accepted but dropped; keep them so latent (de)normalization
    # can use them later.
    self.latents_mean = latents_mean
    self.latents_std = latents_std

    self.encoder = WanEncoder3d(z_dim * 2, z_dim * 2, 1)
    self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)

    self.decoder = WanDecoder3d(
        base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout
    )

src/maxdiffusion/pipelines/wan/__init__.py

Whitespace-only changes.
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from typing import Union, List
18+
from transformers import AutoTokenizer, UMT5EncoderModel
19+
import torch
20+
from ...models.wan.transformers.transformer_flux_wan_nnx import WanModel
21+
from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan
22+
from ..pipeline_flax_utils import FlaxDiffusionPipeline
23+
from ...video_processor import VideoProcessor
24+
from ...schedulers import FlowMatchEulerDiscreteScheduler
25+
26+
class WanPipeline(FlaxDiffusionPipeline):
  """Text-to-video generation pipeline for Wan 2.1 (work in progress).

  Combines a UMT5 text encoder, a Wan transformer, the Wan 3D causal VAE and a
  flow-match Euler scheduler.
  """

  def __init__(
      self,
      tokenizer: AutoTokenizer,
      text_encoder: UMT5EncoderModel,
      transformer: WanModel,
      vae: AutoencoderKLWan,
      scheduler: FlowMatchEulerDiscreteScheduler,
  ):
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
        scheduler=scheduler,
    )

    # Derive scale factors from the VAE config when it is loaded; otherwise
    # fall back to the Wan 2.1 defaults (4x temporal, 8x spatial), e.g. when
    # constructed via from_pretrained(..., vae=None).
    # Fix: the local AutoencoderKLWan exposes `temporal_downsample`; the
    # previous `temperal_downsample` (an upstream Diffusers spelling) raised
    # AttributeError whenever a vae was supplied.
    if getattr(self, "vae", None):
      self.vae_scale_factor_temporal = 2 ** sum(self.vae.temporal_downsample)
      self.vae_scale_factor_spatial = 2 ** len(self.vae.temporal_downsample)
    else:
      self.vae_scale_factor_temporal = 4
      self.vae_scale_factor_spatial = 8
    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

  def _get_t5_prompt_embds(
      self,
      prompt: Union[str, List[str]] = None,
      num_videos_per_prompt: int = 1,
      max_sequence_length: int = 226,
  ):
    """Encode `prompt` with the UMT5 text encoder.

    Args:
      prompt: a single prompt or a list of prompts (must not be None).
      num_videos_per_prompt: embeddings are duplicated this many times per
        prompt.
      max_sequence_length: tokenizer padding/truncation length.

    Returns:
      torch.Tensor of shape
      (batch_size * num_videos_per_prompt, max_sequence_length, hidden_dim).
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    text_inputs = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
    seq_lens = mask.gt(0).sum(dim=1).long()

    prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
    # Trim each prompt to its true token length, then zero-pad back to a fixed
    # max_sequence_length so the batch stacks cleanly (padding positions get
    # exact zeros rather than encoder outputs for pad tokens).
    prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
    )

    # Duplicate text embeddings for each generation per prompt, using an
    # mps-friendly repeat/view instead of repeat_interleave.
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds

src/maxdiffusion/schedulers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
_import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"]
4949
_import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"]
5050
_import_structure["scheduling_sde_ve_flax"] = ["FlaxScoreSdeVeScheduler"]
51+
_import_structure["scheduling_flow_match_euler_discrete_flax"] = ["FlowMatchEulerDiscreteScheduler"]
5152
_import_structure["scheduling_utils_flax"] = [
5253
"FlaxKarrasDiffusionSchedulers",
5354
"FlaxSchedulerMixin",
@@ -73,6 +74,7 @@
7374
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
7475
from .scheduling_pndm_flax import FlaxPNDMScheduler
7576
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
77+
from .scheduling_flow_match_euler_discrete_flax import FlowMatchEulerDiscreteScheduler
7678
from .scheduling_utils_flax import (
7779
FlaxKarrasDiffusionSchedulers,
7880
FlaxSchedulerMixin,
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from dataclasses import dataclass
18+
from typing import Optional, Tuple, Union
19+
20+
import flax
21+
import jax.numpy as jnp
22+
23+
from ..configuration_utils import ConfigMixin, register_to_config
24+
from .scheduling_utils_flax import (
25+
CommonSchedulerState,
26+
# FlaxKarrasDiffusionSchedulers,
27+
FlaxSchedulerMixin,
28+
FlaxSchedulerOutput,
29+
broadcast_to_shape_from_left,
30+
)
31+
32+
@flax.struct.dataclass
class FlowMatchEulerDiscreteSchedulerState:
  # Immutable (pytree-compatible) scheduler state; currently only wraps the
  # shared CommonSchedulerState.
  common: CommonSchedulerState
35+
36+
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(FlaxSchedulerOutput):
  # Output of a scheduler step: the base class's prev_sample plus the updated
  # scheduler state.
  state: FlowMatchEulerDiscreteSchedulerState
39+
40+
class FlowMatchEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
  """Flax flow-match Euler discrete scheduler (work in progress).

  Config parameters mirror the Diffusers FlowMatchEulerDiscreteScheduler; only
  state creation is implemented so far.
  """

  # TODO(wip): register compatibles once the scheduler is functional.
  # _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

  dtype: jnp.dtype

  @property
  def has_state(self):
    # This scheduler keeps its mutable data in an explicit state object
    # (see create_state) rather than on the instance.
    return True

  @register_to_config
  def __init__(
      self,
      num_train_timesteps: int = 1000,
      shift: float = 1.0,
      use_dynamic_shifting: bool = False,
      base_shift: Optional[float] = 0.5,
      max_shift: Optional[float] = 1.15,
      base_image_seq_len: Optional[int] = 256,
      max_image_seq_len: Optional[int] = 4096,
      invert_sigmas: bool = False,
      shift_terminal: Optional[float] = None,
      use_karras_sigmas: Optional[bool] = False,
      use_exponential_sigmas: Optional[bool] = False,
      use_beta_sigmas: Optional[bool] = False,
      time_shift_type: str = "exponential",
      dtype: jnp.dtype = jnp.float32,
  ):
    self.dtype = dtype

  def create_state(self, common: Optional[CommonSchedulerState] = None) -> FlowMatchEulerDiscreteSchedulerState:
    """Create the initial scheduler state.

    Args:
      common: optionally reuse an existing CommonSchedulerState; a fresh one
        is created from this scheduler's config when omitted.

    Returns:
      A FlowMatchEulerDiscreteSchedulerState wrapping `common`.
    """
    if common is None:
      common = CommonSchedulerState.create(self)
    # Fix: the state was previously constructed context but never returned,
    # so create_state yielded None despite its return annotation.
    return FlowMatchEulerDiscreteSchedulerState(common=common)

0 commit comments

Comments
 (0)