31 | 31 | from ...models.wan.transformers.transformer_wan_vace import WanVACEModel |
32 | 32 | from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler |
33 | 33 | from ...models.modeling_flax_pytorch_utils import torch2jax |
34 | | -from .wan_pipeline import WanPipeline, cast_with_exclusion |
| 34 | +from .wan_pipeline import cast_with_exclusion |
| 35 | +from .wan_pipeline_2_1 import WanPipeline2_1 |
35 | 36 | import torch |
36 | 37 | import PIL |
37 | 38 |
@@ -125,12 +126,12 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): |
125 | 126 | return wan_transformer |
126 | 127 |
127 | 128 |
128 | | -class VaceWanPipeline(WanPipeline): |
| 129 | +class VaceWanPipeline2_1(WanPipeline2_1): |
129 | 130 | r"""Pipeline for video generation using Wan + VACE. |
130 | 131 |
131 | 132 | Currently it only supports reference image(s) + text to video generation. |
132 | 133 |
133 | | - It extends `WanPipeline` to support additional conditioning signals. |
| 134 | + It extends `WanPipeline2_1` to support additional conditioning signals. |
134 | 135 |
135 | 136 | tokenizer ([`T5Tokenizer`]): |
136 | 137 | Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), |
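
For context, a minimal call sketch of the renamed pipeline. This assumes a diffusers-style from_pretrained entry point and the reference-image + text-to-video interface described in the docstring; the checkpoint id, argument names, and call signature below are illustrative assumptions, not taken from this change:

    # Illustrative sketch only: constructor and __call__ signatures may differ in this repository.
    import PIL.Image

    reference_image = PIL.Image.open("reference.png")  # hypothetical reference image
    pipeline = VaceWanPipeline2_1.from_pretrained("Wan-AI/Wan2.1-VACE-1.3B-diffusers")  # hypothetical checkpoint id

    output = pipeline(
        prompt="a cat surfing a wave at sunset",
        reference_images=[reference_image],  # reference image(s) + text, per the docstring
        height=480,
        width=832,
        num_frames=81,
        conditioning_scale=1.0,  # scalar, broadcast to one value per VACE layer (see the __call__ hunk below)
    )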
@@ -164,8 +165,6 @@ def preprocess_conditions( |
164 | 165 | if video is not None: |
165 | 166 | base = self.vae_scale_factor_spatial * ( |
166 | 167 | self.transformer.config.patch_size[1] |
167 | | - if self.transformer is not None |
168 | | - else self.transformer_2.config.patch_size[1] |
169 | 168 | ) |
170 | 169 | video_height, video_width = self.video_processor.get_default_height_width(video[0]) |
171 | 170 |
@@ -280,11 +279,7 @@ def prepare_masks( |
280 | 279 | "Generating with more than one video is not yet supported. This may be supported in the future." |
281 | 280 | ) |
282 | 281 |
283 | | - transformer_patch_size = ( |
284 | | - self.transformer.config.patch_size[1] |
285 | | - if self.transformer is not None |
286 | | - else self.transformer_2.config.patch_size[1] |
287 | | - ) |
| 282 | + transformer_patch_size = self.transformer.config.patch_size[1] |
288 | 283 |
289 | 284 | mask_list = [] |
290 | 285 | for mask_, reference_images_batch in zip(mask, reference_images): |
@@ -374,11 +369,9 @@ def check_inputs( |
374 | 369 | ): |
375 | 370 | if self.transformer is not None: |
376 | 371 | base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1] |
377 | | - elif self.transformer_2 is not None: |
378 | | - base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1] |
379 | 372 | else: |
380 | 373 | raise ValueError( |
381 | | - "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline" |
| 374 | + "`transformer` component must be set in order to run inference with this pipeline" |
382 | 375 | ) |
383 | 376 |
384 | 377 | if height % base != 0 or width % base != 0: |
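
A worked example of the divisibility check above, assuming the typical Wan 2.1 values of vae_scale_factor_spatial = 8 and a spatial patch size of 2 (neither value is stated in this diff):

    # Hypothetical values for illustration only.
    vae_scale_factor_spatial = 8
    patch_width = 2
    base = vae_scale_factor_spatial * patch_width   # base = 16

    # A 480x832 request passes the check, because 480 % 16 == 0 and 832 % 16 == 0;
    # a 481x832 request would raise ValueError, because 481 % 16 != 0.
    assert 480 % base == 0 and 832 % base == 0
    assert 481 % base != 0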
@@ -520,12 +513,7 @@ def __call__( |
520 | 513 | ) |
521 | 514 |
522 | 515 | transformer_dtype = self.transformer.proj_out.bias.dtype |
523 | | - |
524 | | - vace_layers = ( |
525 | | - self.transformer.config.vace_layers |
526 | | - if self.transformer is not None |
527 | | - else self.transformer_2.config.vace_layers |
528 | | - ) |
| 516 | + vace_layers = self.transformer.config.vace_layers |
529 | 517 |
530 | 518 | if isinstance(conditioning_scale, (int, float)): |
531 | 519 | conditioning_scale = [conditioning_scale] * len(vace_layers) |
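
The last two lines broadcast a scalar conditioning scale to one value per VACE layer. A small sketch with assumed values (the layer indices are hypothetical; in the pipeline they come from self.transformer.config.vace_layers):

    # Hypothetical VACE layer indices for illustration.
    vace_layers = [0, 5, 10, 15, 20, 25, 30, 35]
    conditioning_scale = 1.0

    if isinstance(conditioning_scale, (int, float)):
        conditioning_scale = [conditioning_scale] * len(vace_layers)

    # conditioning_scale is now [1.0] * 8, i.e. one scale per VACE block.
    assert conditioning_scale == [1.0] * len(vace_layers)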