Skip to content

Commit 6ecc205

Browse files
committed
sharding before call to vae encoder
1 parent e8bdd82 commit 6ecc205

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,11 @@ def prepare_latents_i2v_base(
529529

530530
vae_dtype = getattr(self.vae, "dtype", jnp.float32)
531531
video_condition = video_condition.astype(vae_dtype)
532+
sharding_spec = P(self.config.mesh_axes[0], None, None, None, None)
533+
video_condition = jax.lax.with_sharding_constraint(
534+
video_condition,
535+
sharding_spec
536+
)
532537

533538
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
534539
encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode()

0 commit comments

Comments
 (0)