Skip to content

Commit f94ee08

Browse files
authored
Update VaceWanPipeline to VaceWanPipeline2_1, fix from_pretrained (#311)
- Rename `VaceWanPipeline` to `VaceWanPipeline2_1`. - Make `VaceWanPipeline2_1` inherit from `WanPipeline2_1`. - Remove references to `self.transformer_2`. This change aligns the class with the `WanPipeline` naming structure and resolves an initialization bug when using the `from_pretrained` method.
1 parent e8bdd82 commit f94ee08

1 file changed

Lines changed: 7 additions & 19 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_vace_pipeline.py renamed to src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
from ...models.wan.transformers.transformer_wan_vace import WanVACEModel
3232
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
3333
from ...models.modeling_flax_pytorch_utils import torch2jax
34-
from .wan_pipeline import WanPipeline, cast_with_exclusion
34+
from .wan_pipeline import cast_with_exclusion
35+
from .wan_pipeline_2_1 import WanPipeline2_1
3536
import torch
3637
import PIL
3738

@@ -125,12 +126,12 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
125126
return wan_transformer
126127

127128

128-
class VaceWanPipeline(WanPipeline):
129+
class VaceWanPipeline2_1(WanPipeline2_1):
129130
r"""Pipeline for video generation using Wan + VACE.
130131
131132
Currently it only supports reference image(s) + text to video generation.
132133
133-
It extends `WanPipeline` to support additional conditioning signals.
134+
It extends `WanPipeline2_1` to support additional conditioning signals.
134135
135136
tokenizer ([`T5Tokenizer`]):
136137
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
@@ -164,8 +165,6 @@ def preprocess_conditions(
164165
if video is not None:
165166
base = self.vae_scale_factor_spatial * (
166167
self.transformer.config.patch_size[1]
167-
if self.transformer is not None
168-
else self.transformer_2.config.patch_size[1]
169168
)
170169
video_height, video_width = self.video_processor.get_default_height_width(video[0])
171170

@@ -280,11 +279,7 @@ def prepare_masks(
280279
"Generating with more than one video is not yet supported. This may be supported in the future."
281280
)
282281

283-
transformer_patch_size = (
284-
self.transformer.config.patch_size[1]
285-
if self.transformer is not None
286-
else self.transformer_2.config.patch_size[1]
287-
)
282+
transformer_patch_size = self.transformer.config.patch_size[1]
288283

289284
mask_list = []
290285
for mask_, reference_images_batch in zip(mask, reference_images):
@@ -374,11 +369,9 @@ def check_inputs(
374369
):
375370
if self.transformer is not None:
376371
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
377-
elif self.transformer_2 is not None:
378-
base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
379372
else:
380373
raise ValueError(
381-
"`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
374+
"`transformer` component must be set in order to run inference with this pipeline"
382375
)
383376

384377
if height % base != 0 or width % base != 0:
@@ -520,12 +513,7 @@ def __call__(
520513
)
521514

522515
transformer_dtype = self.transformer.proj_out.bias.dtype
523-
524-
vace_layers = (
525-
self.transformer.config.vace_layers
526-
if self.transformer is not None
527-
else self.transformer_2.config.vace_layers
528-
)
516+
vace_layers = self.transformer.config.vace_layers
529517

530518
if isinstance(conditioning_scale, (int, float)):
531519
conditioning_scale = [conditioning_scale] * len(vace_layers)

0 commit comments

Comments
 (0)