Skip to content

Commit 635e2dd

Browse files
committed
added changes for i2v 2.1
1 parent e075de4 commit 635e2dd

2 files changed

Lines changed: 18 additions & 5 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -85,6 +85,7 @@ def prepare_latents(
8585
rng: jax.Array,
8686
latents: Optional[jax.Array] = None,
8787
last_image: Optional[jax.Array] = None,
88+
num_videos_per_prompt: int = 1,
8889
) -> Tuple[jax.Array, jax.Array, Optional[jax.Array]]:
8990

9091
if hasattr(image, "detach"):
@@ -96,12 +97,17 @@ def prepare_latents(
9697
last_image = last_image.detach().cpu().numpy()
9798
last_image = jnp.array(last_image)
9899

100+
if num_videos_per_prompt > 1:
101+
image = jnp.repeat(image, num_videos_per_prompt, axis=0)
102+
if last_image is not None:
103+
last_image = jnp.repeat(last_image, num_videos_per_prompt, axis=0)
104+
99105
num_channels_latents = self.vae.z_dim
100106
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
101107
latent_height = height // self.vae_scale_factor_spatial
102108
latent_width = width // self.vae_scale_factor_spatial
103109

104-
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
110+
shape = (batch_size, num_latent_frames, latent_height, latent_width, num_channels_latents)
105111

106112
if latents is None:
107113
latents = randn_tensor(shape, rng, self.config, dtype)
@@ -119,7 +125,6 @@ def prepare_latents(
119125
mask_lat_size = mask_lat_size.reshape(
120126
batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width
121127
)
122-
mask_lat_size = jnp.swapaxes(mask_lat_size, 1, 2)
123128
mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 3, 4, 1))
124129
condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1)
125130

@@ -146,13 +151,20 @@ def __call__(
146151
output_type: Optional[str] = "np",
147152
rng: Optional[jax.Array] = None,
148153
):
154+
155+
if num_videos_per_prompt == 1:
156+
n_devices = jax.device_count()
157+
if n_devices > 1:
158+
num_videos_per_prompt = n_devices
159+
149160
height = height or self.config.height
150161
width = width or self.config.width
151162
num_frames = num_frames or self.config.num_frames
152163

153164
prompt_embeds, negative_prompt_embeds, image_embeds, effective_batch_size = self._prepare_model_inputs_i2v(
154165
prompt, image, negative_prompt, num_videos_per_prompt, max_sequence_length,
155-
prompt_embeds, negative_prompt_embeds, image_embeds, last_image
166+
prompt_embeds, negative_prompt_embeds, image_embeds, last_image,
167+
num_videos_per_prompt=num_videos_per_prompt,
156168
)
157169

158170
image_tensor = self.video_processor.preprocess(image, height=height, width=width)
@@ -174,6 +186,7 @@ def __call__(
174186
rng=latents_rng,
175187
latents=latents,
176188
last_image=last_image_tensor,
189+
num_videos_per_prompt=num_videos_per_prompt,
177190
)
178191

179192
scheduler_state = self.scheduler.set_timesteps(
@@ -254,7 +267,7 @@ def loop_body(step, vals):
254267
if do_classifier_free_guidance:
255268
latents_input = jnp.concatenate([latents, latents], axis=0)
256269

257-
latent_model_input = jnp.concatenate([latents_input, condition], axis=1)
270+
latent_model_input = jnp.concatenate([latents_input, condition], axis=-1)
258271
timestep = jnp.broadcast_to(t, latents.shape[0])
259272

260273

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,7 @@ def prepare_latents(
100100
latent_height = height // self.vae_scale_factor_spatial
101101
latent_width = width // self.vae_scale_factor_spatial
102102

103-
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
103+
shape = (batch_size, num_latent_frames, latent_height, latent_width, num_channels_latents)
104104

105105
if latents is None:
106106
latents = randn_tensor(shape, rng, self.config, dtype)

0 commit comments

Comments (0)