1919from typing import List , Union , Optional , Tuple
2020from ...pyconfig import HyperParameters
2121from functools import partial
22- import numpy as np
2322from flax import nnx
2423from flax .linen import partitioning as nn_partitioning
2524import jax
@@ -88,7 +87,7 @@ def prepare_latents(
8887 last_image : Optional [jax .Array ] = None ,
8988 num_videos_per_prompt : int = 1 ,
9089 ) -> Tuple [jax .Array , jax .Array , Optional [jax .Array ]]:
91-
90+
9291 if hasattr (image , "detach" ):
9392 image = image .detach ().cpu ().numpy ()
9493 image = jnp .array (image )
@@ -97,12 +96,12 @@ def prepare_latents(
9796 if hasattr (last_image , "detach" ):
9897 last_image = last_image .detach ().cpu ().numpy ()
9998 last_image = jnp .array (last_image )
100-
99+
101100 if num_videos_per_prompt > 1 :
102101 image = jnp .repeat (image , num_videos_per_prompt , axis = 0 )
103102 if last_image is not None :
104103 last_image = jnp .repeat (last_image , num_videos_per_prompt , axis = 0 )
105-
104+
106105 num_channels_latents = self .vae .z_dim
107106 num_latent_frames = (num_frames - 1 ) // self .vae_scale_factor_temporal + 1
108107 latent_height = height // self .vae_scale_factor_spatial
@@ -119,16 +118,16 @@ def prepare_latents(
119118 if last_image is None :
120119 mask_lat_size = mask_lat_size .at [:, :, 1 :, :, :].set (0 )
121120 else :
122- mask_lat_size = mask_lat_size .at [:, :, 1 :- 1 , :, :].set (0 )
121+ mask_lat_size = mask_lat_size .at [:, :, 1 :- 1 , :, :].set (0 )
123122 first_frame_mask = mask_lat_size [:, :, 0 :1 ]
124123 first_frame_mask = jnp .repeat (first_frame_mask , self .vae_scale_factor_temporal , axis = 2 )
125124 mask_lat_size = jnp .concatenate ([first_frame_mask , mask_lat_size [:, :, 1 :]], axis = 2 )
126125 mask_lat_size = mask_lat_size .reshape (
127- batch_size ,
126+ batch_size ,
128127 1 ,
129- num_latent_frames ,
130- self .vae_scale_factor_temporal ,
131- latent_height ,
128+ num_latent_frames ,
129+ self .vae_scale_factor_temporal ,
130+ latent_height ,
132131 latent_width
133132 )
134133 mask_lat_size = jnp .transpose (mask_lat_size , (0 , 2 , 4 , 5 , 3 , 1 )).squeeze (- 1 )
@@ -210,7 +209,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
210209 scheduler_state = self .scheduler .set_timesteps (
211210 self .scheduler_state , num_inference_steps = num_inference_steps , shape = latents .shape
212211 )
213-
212+
214213 graphdef , state , rest_of_state = nnx .split (self .transformer , nnx .Param , ...)
215214 data_sharding = NamedSharding (self .mesh , P ())
216215 if self .config .global_batch_size_to_train_on // self .config .per_device_batch_size == 0 :
@@ -234,7 +233,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
234233 scheduler = self .scheduler ,
235234 )
236235
237-
236+
238237 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
239238 latents = p_run_inference (
240239 latents = latents ,
@@ -246,7 +245,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
246245 )
247246 latents = jnp .transpose (latents , (0 , 4 , 1 , 2 , 3 ))
248247 latents = self ._denormalize_latents (latents )
249-
248+
250249 if output_type == "latent" :
251250 return latents
252251 return self ._decode_latents_to_video (latents )
@@ -287,5 +286,5 @@ def run_inference_2_1_i2v(
287286 encoder_hidden_states_image = image_embeds ,
288287 )
289288 noise_pred = jnp .transpose (noise_pred , (0 , 2 , 3 , 4 , 1 ))
290- latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents , return_dict = False )
291- return latents
289+ latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents , return_dict = False )
290+ return latents
0 commit comments