Skip to content

Commit 33d16f7

Browse files
committed
sharding before call to vae encoder
1 parent 6ecc205 commit 33d16f7

1 file changed

Lines changed: 5 additions & 6 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -529,13 +529,12 @@ def prepare_latents_i2v_base(
529529

530530
vae_dtype = getattr(self.vae, "dtype", jnp.float32)
531531
video_condition = video_condition.astype(vae_dtype)
532-
sharding_spec = P(self.config.mesh_axes[0], None, None, None, None)
533-
video_condition = jax.lax.with_sharding_constraint(
534-
video_condition,
535-
sharding_spec
536-
)
537-
538532
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
533+
sharding_spec = P(self.config.mesh_axes[0], None, None, None, None)
534+
video_condition = jax.lax.with_sharding_constraint(
535+
video_condition,
536+
sharding_spec
537+
)
539538
encoded_output = self.vae.encode(video_condition, self.vae_cache)[0].mode()
540539

541540
# Normalize latents

0 commit comments

Comments
 (0)