Skip to content

Commit 7f50b2b

Browse files
committed
modified wan pipelines prepare_latents
1 parent 06cb1ae commit 7f50b2b

2 files changed

Lines changed: 10 additions & 5 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,14 @@ def prepare_latents(
123123
first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2)
124124
mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2)
125125
mask_lat_size = mask_lat_size.reshape(
126-
batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width
126+
batch_size,
127+
1,
128+
num_latent_frames,
129+
self.vae_scale_factor_temporal,
130+
latent_height,
131+
latent_width
127132
)
128-
mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 3, 4, 1))
133+
mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1)
129134
condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1)
130135

131136
return latents, condition, None

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,10 @@ def prepare_latents(
126126
first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2)
127127
mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2)
128128
mask_lat_size = mask_lat_size.reshape(
129-
batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width
129+
batch_size, 1, num_latent_frames, self.vae_scale_factor_temporal, latent_height, latent_width
130130
)
131-
mask_lat_size = jnp.swapaxes(mask_lat_size, 1, 2)
132-
condition = jnp.concatenate([mask_lat_size, latent_condition], axis=1)
131+
mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1)
132+
condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1)
133133

134134
return latents, condition, None
135135

0 commit comments

Comments
 (0)