@@ -162,15 +162,36 @@ def preprocess_conditions(
162162 Based on https://github.com/huggingface/diffusers/blob/17c0e79dbdf53fb6705e9c09cc1a854b84c39249/src/diffusers/pipelines/wan/pipeline_wan_vace.py#L414
163163 """
164164 if video is not None :
165- raise NotImplementedError ("Video support is not yet implemented." )
165+ base = self .vae_scale_factor_spatial * (
166+ self .transformer .config .patch_size [1 ]
167+ if self .transformer is not None
168+ else self .transformer_2 .config .patch_size [1 ]
169+ )
170+ video_height , video_width = self .video_processor .get_default_height_width (video [0 ])
171+
172+ if video_height * video_width > height * width :
173+ scale = min (width / video_width , height / video_height )
174+ video_height , video_width = int (video_height * scale ), int (video_width * scale )
175+
176+ if video_height % base != 0 or video_width % base != 0 :
177+ video_height = (video_height // base ) * base
178+ video_width = (video_width // base ) * base
179+
180+ assert video_height * video_width <= height * width
181+
182+ video = self .video_processor .preprocess_video (video , video_height , video_width )
183+ video = jnp .array (np .asarray (video ), dtype = dtype )
184+ image_size = (video_height , video_width ) # Use the height/width of video (with possible rescaling)
166185 else :
167186 video = jnp .zeros (
168- (batch_size , num_frames , height , width , 3 ), dtype = dtype
187+ (batch_size , 3 , num_frames , height , width ), dtype = dtype
169188 )
170189 image_size = (height , width ) # Use the height/width provided by the user
171190
172191 if mask is not None :
173- raise NotImplementedError ("Mask support is not yet implemented." )
192+ mask = self .video_processor .preprocess_video (mask , image_size [0 ], image_size [1 ])
193+ mask = jnp .array (np .asarray (mask ), dtype = video .dtype )
194+ mask = jnp .clip ((mask + 1 ) / 2 , 0 , 1 )  # NOTE(review): a_min/a_max kwargs were removed in newer JAX; positional bounds work on all versions
174195 else :
175196 mask = jnp .ones_like (video )
176197
@@ -239,10 +260,10 @@ def preprocess_conditions(
239260
240261 def prepare_masks (
241262 self ,
242- mask : torch . Tensor ,
263+ mask : jax . Array ,
243264 reference_images : Optional [List [torch .Tensor ]] = None ,
244265 ):
245- masks_torch = torch .Tensor (np .array (mask ). transpose ( 0 , 4 , 1 , 2 , 3 ) )
266+ masks_torch = torch .Tensor (np .array (mask ))
246267 mask = masks_torch
247268 if reference_images is None :
248269 # For each batch of video, we set no reference image (as one or more can
@@ -614,8 +635,8 @@ def __call__(
614635 def prepare_video_latents (
615636 self ,
616637 data_sharding : NamedSharding ,
617- video : torch . Tensor ,
618- mask : torch . Tensor ,
638+ video : jax . Array ,
639+ mask : jax . Array ,
619640 reference_images : Optional [List [List [torch .Tensor ]]] = None ,
620641 rngs = None ,
621642 ) -> jax .Array :
0 commit comments