Skip to content

Commit 73259e3

Browse files
martinarroyo and ninatu committed
Adds the logic to condition on videos and masks
Co-authored-by: ninatu <ninashv@google.com>
1 parent db2a559 commit 73259e3

1 file changed

Lines changed: 28 additions & 7 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_vace_pipeline.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,15 +162,36 @@ def preprocess_conditions(
162162
Based on https://github.com/huggingface/diffusers/blob/17c0e79dbdf53fb6705e9c09cc1a854b84c39249/src/diffusers/pipelines/wan/pipeline_wan_vace.py#L414
163163
"""
164164
if video is not None:
165-
raise NotImplementedError("Video support is not yet implemented.")
165+
base = self.vae_scale_factor_spatial * (
166+
self.transformer.config.patch_size[1]
167+
if self.transformer is not None
168+
else self.transformer_2.config.patch_size[1]
169+
)
170+
video_height, video_width = self.video_processor.get_default_height_width(video[0])
171+
172+
if video_height * video_width > height * width:
173+
scale = min(width / video_width, height / video_height)
174+
video_height, video_width = int(video_height * scale), int(video_width * scale)
175+
176+
if video_height % base != 0 or video_width % base != 0:
177+
video_height = (video_height // base) * base
178+
video_width = (video_width // base) * base
179+
180+
assert video_height * video_width <= height * width
181+
182+
video = self.video_processor.preprocess_video(video, video_height, video_width)
183+
video = jnp.array(np.asarray(video), dtype=dtype)
184+
image_size = (video_height, video_width) # Use the height/width of video (with possible rescaling)
166185
else:
167186
video = jnp.zeros(
168-
(batch_size, num_frames, height, width, 3), dtype=dtype
187+
(batch_size, 3, num_frames, height, width), dtype=dtype
169188
)
170189
image_size = (height, width) # Use the height/width provided by the user
171190

172191
if mask is not None:
173-
raise NotImplementedError("Mask support is not yet implemented.")
192+
mask = self.video_processor.preprocess_video(mask, image_size[0], image_size[1])
193+
mask = jnp.array(np.asarray(mask), dtype=video.dtype)
194+
mask = jnp.clip((mask + 1) / 2, a_min=0, a_max=1)
174195
else:
175196
mask = jnp.ones_like(video)
176197

@@ -239,10 +260,10 @@ def preprocess_conditions(
239260

240261
def prepare_masks(
241262
self,
242-
mask: torch.Tensor,
263+
mask: jax.Array,
243264
reference_images: Optional[List[torch.Tensor]] = None,
244265
):
245-
masks_torch = torch.Tensor(np.array(mask).transpose(0, 4, 1, 2, 3))
266+
masks_torch = torch.Tensor(np.array(mask))
246267
mask = masks_torch
247268
if reference_images is None:
248269
# For each batch of video, we set no reference image (as one or more can
@@ -614,8 +635,8 @@ def __call__(
614635
def prepare_video_latents(
615636
self,
616637
data_sharding: NamedSharding,
617-
video: torch.Tensor,
618-
mask: torch.Tensor,
638+
video: jax.Array,
639+
mask: jax.Array,
619640
reference_images: Optional[List[List[torch.Tensor]]] = None,
620641
rngs=None,
621642
) -> jax.Array:

0 commit comments

Comments
 (0)